diff --git a/requirements/requirements-codestyle.txt b/requirements/requirements-codestyle.txt
index 2b7bad436..fcdfb54e8 100644
--- a/requirements/requirements-codestyle.txt
+++ b/requirements/requirements-codestyle.txt
@@ -4,7 +4,7 @@ flake8-tidy-imports==1.1.0
pycodestyle==2.3.1
# Sort and lint imports
-isort==4.3.3
+isort==4.3.17
# black
black==19.3b0
\ No newline at end of file
diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py
index 55c06982d..cd22cd368 100644
--- a/rest_framework/__init__.py
+++ b/rest_framework/__init__.py
@@ -7,22 +7,22 @@ ______ _____ _____ _____ __
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
"""
-__title__ = 'Django REST framework'
-__version__ = '3.9.2'
-__author__ = 'Tom Christie'
-__license__ = 'BSD 2-Clause'
-__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd'
+__title__ = "Django REST framework"
+__version__ = "3.9.2"
+__author__ = "Tom Christie"
+__license__ = "BSD 2-Clause"
+__copyright__ = "Copyright 2011-2019 Encode OSS Ltd"
# Version synonym
VERSION = __version__
# Header encoding (see RFC5987)
-HTTP_HEADER_ENCODING = 'iso-8859-1'
+HTTP_HEADER_ENCODING = "iso-8859-1"
# Default datetime input and output formats
-ISO_8601 = 'iso-8601'
+ISO_8601 = "iso-8601"
-default_app_config = 'rest_framework.apps.RestFrameworkConfig'
+default_app_config = "rest_framework.apps.RestFrameworkConfig"
class RemovedInDRF310Warning(DeprecationWarning):
diff --git a/rest_framework/apps.py b/rest_framework/apps.py
index f6013eb7e..af2a09038 100644
--- a/rest_framework/apps.py
+++ b/rest_framework/apps.py
@@ -2,7 +2,7 @@ from django.apps import AppConfig
class RestFrameworkConfig(AppConfig):
- name = 'rest_framework'
+ name = "rest_framework"
verbose_name = "Django REST framework"
def ready(self):
diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py
index 25150d525..19b74f115 100644
--- a/rest_framework/authentication.py
+++ b/rest_framework/authentication.py
@@ -20,7 +20,7 @@ def get_authorization_header(request):
Hide some test client ickyness where the header can be unicode.
"""
- auth = request.META.get('HTTP_AUTHORIZATION', b'')
+ auth = request.META.get("HTTP_AUTHORIZATION", b"")
if isinstance(auth, text_type):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
@@ -57,7 +57,8 @@ class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
- www_authenticate_realm = 'api'
+
+ www_authenticate_realm = "api"
def authenticate(self, request):
"""
@@ -66,20 +67,24 @@ class BasicAuthentication(BaseAuthentication):
"""
auth = get_authorization_header(request).split()
- if not auth or auth[0].lower() != b'basic':
+ if not auth or auth[0].lower() != b"basic":
return None
if len(auth) == 1:
- msg = _('Invalid basic header. No credentials provided.')
+ msg = _("Invalid basic header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
- msg = _('Invalid basic header. Credentials string should not contain spaces.')
+ msg = _(
+ "Invalid basic header. Credentials string should not contain spaces."
+ )
raise exceptions.AuthenticationFailed(msg)
try:
- auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')
+ auth_parts = (
+ base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(":")
+ )
except (TypeError, UnicodeDecodeError, binascii.Error):
- msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
+ msg = _("Invalid basic header. Credentials not correctly base64 encoded.")
raise exceptions.AuthenticationFailed(msg)
userid, password = auth_parts[0], auth_parts[2]
@@ -90,17 +95,14 @@ class BasicAuthentication(BaseAuthentication):
Authenticate the userid and password against username and password
with optional request for context.
"""
- credentials = {
- get_user_model().USERNAME_FIELD: userid,
- 'password': password
- }
+ credentials = {get_user_model().USERNAME_FIELD: userid, "password": password}
user = authenticate(request=request, **credentials)
if user is None:
- raise exceptions.AuthenticationFailed(_('Invalid username/password.'))
+ raise exceptions.AuthenticationFailed(_("Invalid username/password."))
if not user.is_active:
- raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
+ raise exceptions.AuthenticationFailed(_("User inactive or deleted."))
return (user, None)
@@ -120,7 +122,7 @@ class SessionAuthentication(BaseAuthentication):
"""
# Get the session-based user from the underlying HttpRequest object
- user = getattr(request._request, 'user', None)
+ user = getattr(request._request, "user", None)
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
@@ -141,7 +143,7 @@ class SessionAuthentication(BaseAuthentication):
reason = check.process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
- raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
+ raise exceptions.PermissionDenied("CSRF Failed: %s" % reason)
class TokenAuthentication(BaseAuthentication):
@@ -154,13 +156,14 @@ class TokenAuthentication(BaseAuthentication):
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
"""
- keyword = 'Token'
+ keyword = "Token"
model = None
def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
+
return Token
"""
@@ -177,16 +180,18 @@ class TokenAuthentication(BaseAuthentication):
return None
if len(auth) == 1:
- msg = _('Invalid token header. No credentials provided.')
+ msg = _("Invalid token header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
- msg = _('Invalid token header. Token string should not contain spaces.')
+ msg = _("Invalid token header. Token string should not contain spaces.")
raise exceptions.AuthenticationFailed(msg)
try:
token = auth[1].decode()
except UnicodeError:
- msg = _('Invalid token header. Token string should not contain invalid characters.')
+ msg = _(
+ "Invalid token header. Token string should not contain invalid characters."
+ )
raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(token)
@@ -194,12 +199,12 @@ class TokenAuthentication(BaseAuthentication):
def authenticate_credentials(self, key):
model = self.get_model()
try:
- token = model.objects.select_related('user').get(key=key)
+ token = model.objects.select_related("user").get(key=key)
except model.DoesNotExist:
- raise exceptions.AuthenticationFailed(_('Invalid token.'))
+ raise exceptions.AuthenticationFailed(_("Invalid token."))
if not token.user.is_active:
- raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
+ raise exceptions.AuthenticationFailed(_("User inactive or deleted."))
return (token.user, token)
diff --git a/rest_framework/authtoken/__init__.py b/rest_framework/authtoken/__init__.py
index 82f5b9171..bc19a2e04 100644
--- a/rest_framework/authtoken/__init__.py
+++ b/rest_framework/authtoken/__init__.py
@@ -1 +1 @@
-default_app_config = 'rest_framework.authtoken.apps.AuthTokenConfig'
+default_app_config = "rest_framework.authtoken.apps.AuthTokenConfig"
diff --git a/rest_framework/authtoken/admin.py b/rest_framework/authtoken/admin.py
index 1a507249b..f2ca70ec2 100644
--- a/rest_framework/authtoken/admin.py
+++ b/rest_framework/authtoken/admin.py
@@ -4,9 +4,9 @@ from rest_framework.authtoken.models import Token
class TokenAdmin(admin.ModelAdmin):
- list_display = ('key', 'user', 'created')
- fields = ('user',)
- ordering = ('-created',)
+ list_display = ("key", "user", "created")
+ fields = ("user",)
+ ordering = ("-created",)
admin.site.register(Token, TokenAdmin)
diff --git a/rest_framework/authtoken/apps.py b/rest_framework/authtoken/apps.py
index ad01cb404..7b2aac0c6 100644
--- a/rest_framework/authtoken/apps.py
+++ b/rest_framework/authtoken/apps.py
@@ -3,5 +3,5 @@ from django.utils.translation import ugettext_lazy as _
class AuthTokenConfig(AppConfig):
- name = 'rest_framework.authtoken'
+ name = "rest_framework.authtoken"
verbose_name = _("Auth Token")
diff --git a/rest_framework/authtoken/management/commands/drf_create_token.py b/rest_framework/authtoken/management/commands/drf_create_token.py
index 8e06812db..5dc41a97d 100644
--- a/rest_framework/authtoken/management/commands/drf_create_token.py
+++ b/rest_framework/authtoken/management/commands/drf_create_token.py
@@ -3,11 +3,12 @@ from django.core.management.base import BaseCommand, CommandError
from rest_framework.authtoken.models import Token
+
UserModel = get_user_model()
class Command(BaseCommand):
- help = 'Create DRF Token for a given user'
+ help = "Create DRF Token for a given user"
def create_user_token(self, username, reset_token):
user = UserModel._default_manager.get_by_natural_key(username)
@@ -19,27 +20,27 @@ class Command(BaseCommand):
return token[0]
def add_arguments(self, parser):
- parser.add_argument('username', type=str)
+ parser.add_argument("username", type=str)
parser.add_argument(
- '-r',
- '--reset',
- action='store_true',
- dest='reset_token',
+ "-r",
+ "--reset",
+ action="store_true",
+ dest="reset_token",
default=False,
- help='Reset existing User token and create a new one',
+ help="Reset existing User token and create a new one",
)
def handle(self, *args, **options):
- username = options['username']
- reset_token = options['reset_token']
+ username = options["username"]
+ reset_token = options["reset_token"]
try:
token = self.create_user_token(username, reset_token)
except UserModel.DoesNotExist:
raise CommandError(
- 'Cannot create the Token: user {0} does not exist'.format(
- username)
+ "Cannot create the Token: user {0} does not exist".format(username)
)
self.stdout.write(
- 'Generated token {0} for user {1}'.format(token.key, username))
+ "Generated token {0} for user {1}".format(token.key, username)
+ )
diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py
index 75780fedf..708bade1a 100644
--- a/rest_framework/authtoken/migrations/0001_initial.py
+++ b/rest_framework/authtoken/migrations/0001_initial.py
@@ -7,20 +7,27 @@ from django.db import migrations, models
class Migration(migrations.Migration):
- dependencies = [
- migrations.swappable_dependency(settings.AUTH_USER_MODEL),
- ]
+ dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)]
operations = [
migrations.CreateModel(
- name='Token',
+ name="Token",
fields=[
- ('key', models.CharField(primary_key=True, serialize=False, max_length=40)),
- ('created', models.DateTimeField(auto_now_add=True)),
- ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token', on_delete=models.CASCADE)),
+ (
+ "key",
+ models.CharField(primary_key=True, serialize=False, max_length=40),
+ ),
+ ("created", models.DateTimeField(auto_now_add=True)),
+ (
+ "user",
+ models.OneToOneField(
+ to=settings.AUTH_USER_MODEL,
+ related_name="auth_token",
+ on_delete=models.CASCADE,
+ ),
+ ),
],
- options={
- },
+ options={},
bases=(models.Model,),
- ),
+ )
]
diff --git a/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py b/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py
index 9f7e58e22..ac404c764 100644
--- a/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py
+++ b/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py
@@ -7,28 +7,33 @@ from django.db import migrations, models
class Migration(migrations.Migration):
- dependencies = [
- ('authtoken', '0001_initial'),
- ]
+ dependencies = [("authtoken", "0001_initial")]
operations = [
migrations.AlterModelOptions(
- name='token',
- options={'verbose_name_plural': 'Tokens', 'verbose_name': 'Token'},
+ name="token",
+ options={"verbose_name_plural": "Tokens", "verbose_name": "Token"},
),
migrations.AlterField(
- model_name='token',
- name='created',
- field=models.DateTimeField(verbose_name='Created', auto_now_add=True),
+ model_name="token",
+ name="created",
+ field=models.DateTimeField(verbose_name="Created", auto_now_add=True),
),
migrations.AlterField(
- model_name='token',
- name='key',
- field=models.CharField(verbose_name='Key', max_length=40, primary_key=True, serialize=False),
+ model_name="token",
+ name="key",
+ field=models.CharField(
+ verbose_name="Key", max_length=40, primary_key=True, serialize=False
+ ),
),
migrations.AlterField(
- model_name='token',
- name='user',
- field=models.OneToOneField(to=settings.AUTH_USER_MODEL, verbose_name='User', related_name='auth_token', on_delete=models.CASCADE),
+ model_name="token",
+ name="user",
+ field=models.OneToOneField(
+ to=settings.AUTH_USER_MODEL,
+ verbose_name="User",
+ related_name="auth_token",
+ on_delete=models.CASCADE,
+ ),
),
]
diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py
index 7e96eff93..d99a8a524 100644
--- a/rest_framework/authtoken/models.py
+++ b/rest_framework/authtoken/models.py
@@ -12,10 +12,13 @@ class Token(models.Model):
"""
The default authorization token model.
"""
+
key = models.CharField(_("Key"), max_length=40, primary_key=True)
user = models.OneToOneField(
- settings.AUTH_USER_MODEL, related_name='auth_token',
- on_delete=models.CASCADE, verbose_name=_("User")
+ settings.AUTH_USER_MODEL,
+ related_name="auth_token",
+ on_delete=models.CASCADE,
+ verbose_name=_("User"),
)
created = models.DateTimeField(_("Created"), auto_now_add=True)
@@ -25,7 +28,7 @@ class Token(models.Model):
#
# Also see corresponding ticket:
# https://github.com/encode/django-rest-framework/issues/705
- abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
+ abstract = "rest_framework.authtoken" not in settings.INSTALLED_APPS
verbose_name = _("Token")
verbose_name_plural = _("Tokens")
diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py
index e5f46dd66..0587fe0a5 100644
--- a/rest_framework/authtoken/serializers.py
+++ b/rest_framework/authtoken/serializers.py
@@ -7,28 +7,29 @@ from rest_framework import serializers
class AuthTokenSerializer(serializers.Serializer):
username = serializers.CharField(label=_("Username"))
password = serializers.CharField(
- label=_("Password"),
- style={'input_type': 'password'},
- trim_whitespace=False
+ label=_("Password"), style={"input_type": "password"}, trim_whitespace=False
)
def validate(self, attrs):
- username = attrs.get('username')
- password = attrs.get('password')
+ username = attrs.get("username")
+ password = attrs.get("password")
if username and password:
- user = authenticate(request=self.context.get('request'),
- username=username, password=password)
+ user = authenticate(
+ request=self.context.get("request"),
+ username=username,
+ password=password,
+ )
# The authenticate call simply returns None for is_active=False
# users. (Assuming the default ModelBackend authentication
# backend.)
if not user:
- msg = _('Unable to log in with provided credentials.')
- raise serializers.ValidationError(msg, code='authorization')
+ msg = _("Unable to log in with provided credentials.")
+ raise serializers.ValidationError(msg, code="authorization")
else:
msg = _('Must include "username" and "password".')
- raise serializers.ValidationError(msg, code='authorization')
+ raise serializers.ValidationError(msg, code="authorization")
- attrs['user'] = user
+ attrs["user"] = user
return attrs
diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py
index a8c751d51..f73cbb295 100644
--- a/rest_framework/authtoken/views.py
+++ b/rest_framework/authtoken/views.py
@@ -10,7 +10,7 @@ from rest_framework.views import APIView
class ObtainAuthToken(APIView):
throttle_classes = ()
permission_classes = ()
- parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
+ parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser)
renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer
if coreapi is not None and coreschema is not None:
@@ -19,7 +19,7 @@ class ObtainAuthToken(APIView):
coreapi.Field(
name="username",
required=True,
- location='form',
+ location="form",
schema=coreschema.String(
title="Username",
description="Valid username for authentication",
@@ -28,7 +28,7 @@ class ObtainAuthToken(APIView):
coreapi.Field(
name="password",
required=True,
- location='form',
+ location="form",
schema=coreschema.String(
title="Password",
description="Valid password for authentication",
@@ -39,12 +39,13 @@ class ObtainAuthToken(APIView):
)
def post(self, request, *args, **kwargs):
- serializer = self.serializer_class(data=request.data,
- context={'request': request})
+ serializer = self.serializer_class(
+ data=request.data, context={"request": request}
+ )
serializer.is_valid(raise_exception=True)
- user = serializer.validated_data['user']
+ user = serializer.validated_data["user"]
token, created = Token.objects.get_or_create(user=user)
- return Response({'token': token.key})
+ return Response({"token": token.key})
obtain_auth_token = ObtainAuthToken.as_view()
diff --git a/rest_framework/checks.py b/rest_framework/checks.py
index c1e626018..fe17f0046 100644
--- a/rest_framework/checks.py
+++ b/rest_framework/checks.py
@@ -6,16 +6,17 @@ def pagination_system_check(app_configs, **kwargs):
errors = []
# Use of default page size setting requires a default Paginator class
from rest_framework.settings import api_settings
+
if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS:
errors.append(
Warning(
"You have specified a default PAGE_SIZE pagination rest_framework setting,"
"without specifying also a DEFAULT_PAGINATION_CLASS.",
hint="The default for DEFAULT_PAGINATION_CLASS is None. "
- "In previous versions this was PageNumberPagination. "
- "If you wish to define PAGE_SIZE globally whilst defining "
- "pagination_class on a per-view basis you may silence this check.",
- id="rest_framework.W001"
+ "In previous versions this was PageNumberPagination. "
+ "If you wish to define PAGE_SIZE globally whilst defining "
+ "pagination_class on a per-view basis you may silence this check.",
+ id="rest_framework.W001",
)
)
return errors
diff --git a/rest_framework/compat.py b/rest_framework/compat.py
index 9422e6ad5..9026c1357 100644
--- a/rest_framework/compat.py
+++ b/rest_framework/compat.py
@@ -12,18 +12,16 @@ from django.core import validators
from django.utils import six
from django.views.generic import View
-try:
- # Python 3
- from collections.abc import Mapping, MutableMapping # noqa
-except ImportError:
- # Python 2.7
- from collections import Mapping, MutableMapping # noqa
try:
- from django.urls import ( # noqa
- URLPattern,
- URLResolver,
- )
+ # Python 3
+ from collections.abc import Mapping, MutableMapping # noqa
+except ImportError:
+ # Python 2.7
+ from collections import Mapping, MutableMapping # noqa
+
+try:
+ from django.urls import URLPattern, URLResolver # noqa
except ImportError:
# Will be removed in Django 2.0
from django.urls import ( # noqa
@@ -47,7 +45,7 @@ def get_original_route(urlpattern):
Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This
is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path().
"""
- if hasattr(urlpattern, 'pattern'):
+ if hasattr(urlpattern, "pattern"):
# Django 2.0
return str(urlpattern.pattern)
else:
@@ -60,7 +58,7 @@ def get_regex_pattern(urlpattern):
Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression,
unlike get_original_route above.
"""
- if hasattr(urlpattern, 'pattern'):
+ if hasattr(urlpattern, "pattern"):
# Django 2.0
return urlpattern.pattern.regex.pattern
else:
@@ -69,9 +67,10 @@ def get_regex_pattern(urlpattern):
def is_route_pattern(urlpattern):
- if hasattr(urlpattern, 'pattern'):
+ if hasattr(urlpattern, "pattern"):
# Django 2.0
from django.urls.resolvers import RoutePattern
+
return isinstance(urlpattern.pattern, RoutePattern)
else:
# Django < 2.0
@@ -82,6 +81,7 @@ def make_url_resolver(regex, urlpatterns):
try:
# Django 2.0
from django.urls.resolvers import RegexPattern
+
return URLResolver(RegexPattern(regex), urlpatterns)
except ImportError:
@@ -93,7 +93,7 @@ def unicode_repr(instance):
# Get the repr of an instance, but ensure it is a unicode string
# on both python 3 (already the case) and 2 (not the case).
if six.PY2:
- return repr(instance).decode('utf-8')
+ return repr(instance).decode("utf-8")
return repr(instance)
@@ -102,21 +102,21 @@ def unicode_to_repr(value):
# the Python version. We wrap all our `__repr__` implementations with
# this and then use unicode throughout internally.
if six.PY2:
- return value.encode('utf-8')
+ return value.encode("utf-8")
return value
def unicode_http_header(value):
# Coerce HTTP header value to unicode.
if isinstance(value, bytes):
- return value.decode('iso-8859-1')
+ return value.decode("iso-8859-1")
return value
def distinct(queryset, base):
if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle":
# distinct analogue for Oracle users
- return base.filter(pk__in=set(queryset.values_list('pk', flat=True)))
+ return base.filter(pk__in=set(queryset.values_list("pk", flat=True)))
return queryset.distinct()
@@ -172,27 +172,27 @@ def is_guardian_installed():
# Guardian 1.5.0, for Django 2.2 is NOT compatible with Python 2.7.
# Remove when dropping PY2.
return False
- return 'guardian' in settings.INSTALLED_APPS
+ return "guardian" in settings.INSTALLED_APPS
# PATCH method is not implemented by Django
-if 'patch' not in View.http_method_names:
- View.http_method_names = View.http_method_names + ['patch']
+if "patch" not in View.http_method_names:
+ View.http_method_names = View.http_method_names + ["patch"]
# Markdown is optional
try:
import markdown
- if markdown.version <= '2.2':
- HEADERID_EXT_PATH = 'headerid'
- LEVEL_PARAM = 'level'
- elif markdown.version < '2.6':
- HEADERID_EXT_PATH = 'markdown.extensions.headerid'
- LEVEL_PARAM = 'level'
+ if markdown.version <= "2.2":
+ HEADERID_EXT_PATH = "headerid"
+ LEVEL_PARAM = "level"
+ elif markdown.version < "2.6":
+ HEADERID_EXT_PATH = "markdown.extensions.headerid"
+ LEVEL_PARAM = "level"
else:
- HEADERID_EXT_PATH = 'markdown.extensions.toc'
- LEVEL_PARAM = 'baselevel'
+ HEADERID_EXT_PATH = "markdown.extensions.toc"
+ LEVEL_PARAM = "baselevel"
def apply_markdown(text):
"""
@@ -200,16 +200,14 @@ try:
of '#' style headers to
.
"""
extensions = [HEADERID_EXT_PATH]
- extension_configs = {
- HEADERID_EXT_PATH: {
- LEVEL_PARAM: '2'
- }
- }
+ extension_configs = {HEADERID_EXT_PATH: {LEVEL_PARAM: "2"}}
md = markdown.Markdown(
extensions=extensions, extension_configs=extension_configs
)
md_filter_add_syntax_highlight(md)
return md.convert(text)
+
+
except ImportError:
apply_markdown = None
markdown = None
@@ -227,7 +225,8 @@ try:
def pygments_css(style):
formatter = HtmlFormatter(style=style)
- return formatter.get_style_defs('.highlight')
+ return formatter.get_style_defs(".highlight")
+
except ImportError:
pygments = None
@@ -238,6 +237,7 @@ except ImportError:
def pygments_css(style):
return None
+
if markdown is not None and pygments is not None:
# starting from this blogpost and modified to support current markdown extensions API
# https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/
@@ -246,8 +246,7 @@ if markdown is not None and pygments is not None:
import re
class CodeBlockPreprocessor(Preprocessor):
- pattern = re.compile(
- r'^\s*``` *([^\n]+)\n(.+?)^\s*```', re.M | re.S)
+ pattern = re.compile(r"^\s*``` *([^\n]+)\n(.+?)^\s*```", re.M | re.S)
formatter = HtmlFormatter()
@@ -257,17 +256,25 @@ if markdown is not None and pygments is not None:
lexer = get_lexer_by_name(m.group(1))
except (ValueError, NameError):
lexer = TextLexer()
- code = m.group(2).replace('\t', ' ')
+ code = m.group(2).replace("\t", " ")
code = pygments.highlight(code, lexer, self.formatter)
- code = code.replace('\n\n', '\n \n').replace('\n', '
').replace('\\@', '@')
- return '\n\n%s\n\n' % code
+ code = (
+ code.replace("\n\n", "\n \n")
+ .replace("\n", "
")
+ .replace("\\@", "@")
+ )
+ return "\n\n%s\n\n" % code
+
ret = self.pattern.sub(repl, "\n".join(lines))
return ret.split("\n")
def md_filter_add_syntax_highlight(md):
- md.preprocessors.add('highlight', CodeBlockPreprocessor(), "_begin")
+ md.preprocessors.add("highlight", CodeBlockPreprocessor(), "_begin")
return True
+
+
else:
+
def md_filter_add_syntax_highlight(md):
return False
@@ -276,7 +283,8 @@ else:
try:
from django.urls import include, path, re_path, register_converter # noqa
except ImportError:
- from django.conf.urls import include, url # noqa
+ from django.conf.urls import include, url # noqa
+
path = None
register_converter = None
re_path = url
@@ -285,13 +293,13 @@ except ImportError:
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767
if six.PY3:
- SHORT_SEPARATORS = (',', ':')
- LONG_SEPARATORS = (', ', ': ')
- INDENT_SEPARATORS = (',', ': ')
+ SHORT_SEPARATORS = (",", ":")
+ LONG_SEPARATORS = (", ", ": ")
+ INDENT_SEPARATORS = (",", ": ")
else:
- SHORT_SEPARATORS = (b',', b':')
- LONG_SEPARATORS = (b', ', b': ')
- INDENT_SEPARATORS = (b',', b': ')
+ SHORT_SEPARATORS = (b",", b":")
+ LONG_SEPARATORS = (b", ", b": ")
+ INDENT_SEPARATORS = (b",", b": ")
class CustomValidatorMessage(object):
@@ -303,7 +311,7 @@ class CustomValidatorMessage(object):
"""
def __init__(self, *args, **kwargs):
- self.message = kwargs.pop('message', self.message)
+ self.message = kwargs.pop("message", self.message)
super(CustomValidatorMessage, self).__init__(*args, **kwargs)
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 30bfcc4e5..6999e6a79 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -23,14 +23,14 @@ def api_view(http_method_names=None):
Decorator that converts a function-based view into an APIView subclass.
Takes a list of allowed methods for the view as an argument.
"""
- http_method_names = ['GET'] if (http_method_names is None) else http_method_names
+ http_method_names = ["GET"] if (http_method_names is None) else http_method_names
def decorator(func):
WrappedAPIView = type(
- six.PY3 and 'WrappedAPIView' or b'WrappedAPIView',
+ six.PY3 and "WrappedAPIView" or b"WrappedAPIView",
(APIView,),
- {'__doc__': func.__doc__}
+ {"__doc__": func.__doc__},
)
# Note, the above allows us to set the docstring.
@@ -41,15 +41,20 @@ def api_view(http_method_names=None):
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
# api_view applied without (method_names)
- assert not(isinstance(http_method_names, types.FunctionType)), \
- '@api_view missing list of allowed HTTP methods'
+ assert not (
+ isinstance(http_method_names, types.FunctionType)
+ ), "@api_view missing list of allowed HTTP methods"
# api_view applied with eg. string instead of list of strings
- assert isinstance(http_method_names, (list, tuple)), \
- '@api_view expected a list of strings, received %s' % type(http_method_names).__name__
+ assert isinstance(http_method_names, (list, tuple)), (
+ "@api_view expected a list of strings, received %s"
+ % type(http_method_names).__name__
+ )
- allowed_methods = set(http_method_names) | {'options'}
- WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
+ allowed_methods = set(http_method_names) | {"options"}
+ WrappedAPIView.http_method_names = [
+ method.lower() for method in allowed_methods
+ ]
def handler(self, *args, **kwargs):
return func(*args, **kwargs)
@@ -60,23 +65,27 @@ def api_view(http_method_names=None):
WrappedAPIView.__name__ = func.__name__
WrappedAPIView.__module__ = func.__module__
- WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
- APIView.renderer_classes)
+ WrappedAPIView.renderer_classes = getattr(
+ func, "renderer_classes", APIView.renderer_classes
+ )
- WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
- APIView.parser_classes)
+ WrappedAPIView.parser_classes = getattr(
+ func, "parser_classes", APIView.parser_classes
+ )
- WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
- APIView.authentication_classes)
+ WrappedAPIView.authentication_classes = getattr(
+ func, "authentication_classes", APIView.authentication_classes
+ )
- WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
- APIView.throttle_classes)
+ WrappedAPIView.throttle_classes = getattr(
+ func, "throttle_classes", APIView.throttle_classes
+ )
- WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
- APIView.permission_classes)
+ WrappedAPIView.permission_classes = getattr(
+ func, "permission_classes", APIView.permission_classes
+ )
- WrappedAPIView.schema = getattr(func, 'schema',
- APIView.schema)
+ WrappedAPIView.schema = getattr(func, "schema", APIView.schema)
return WrappedAPIView.as_view()
@@ -87,6 +96,7 @@ def renderer_classes(renderer_classes):
def decorator(func):
func.renderer_classes = renderer_classes
return func
+
return decorator
@@ -94,6 +104,7 @@ def parser_classes(parser_classes):
def decorator(func):
func.parser_classes = parser_classes
return func
+
return decorator
@@ -101,6 +112,7 @@ def authentication_classes(authentication_classes):
def decorator(func):
func.authentication_classes = authentication_classes
return func
+
return decorator
@@ -108,6 +120,7 @@ def throttle_classes(throttle_classes):
def decorator(func):
func.throttle_classes = throttle_classes
return func
+
return decorator
@@ -115,6 +128,7 @@ def permission_classes(permission_classes):
def decorator(func):
func.permission_classes = permission_classes
return func
+
return decorator
@@ -122,6 +136,7 @@ def schema(view_inspector):
def decorator(func):
func.schema = view_inspector
return func
+
return decorator
@@ -132,15 +147,13 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
Set the `detail` boolean to determine if this action should apply to
instance/detail requests or collection/list requests.
"""
- methods = ['get'] if (methods is None) else methods
+ methods = ["get"] if (methods is None) else methods
methods = [method.lower() for method in methods]
- assert detail is not None, (
- "@action() missing required argument: 'detail'"
- )
+ assert detail is not None, "@action() missing required argument: 'detail'"
# name and suffix are mutually exclusive
- if 'name' in kwargs and 'suffix' in kwargs:
+ if "name" in kwargs and "suffix" in kwargs:
raise TypeError("`name` and `suffix` are mutually exclusive arguments.")
def decorator(func):
@@ -148,15 +161,16 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
func.detail = detail
func.url_path = url_path if url_path else func.__name__
- func.url_name = url_name if url_name else func.__name__.replace('_', '-')
+ func.url_name = url_name if url_name else func.__name__.replace("_", "-")
func.kwargs = kwargs
# Set descriptive arguments for viewsets
- if 'name' not in kwargs and 'suffix' not in kwargs:
- func.kwargs['name'] = pretty_name(func.__name__)
- func.kwargs['description'] = func.__doc__ or None
+ if "name" not in kwargs and "suffix" not in kwargs:
+ func.kwargs["name"] = pretty_name(func.__name__)
+ func.kwargs["description"] = func.__doc__ or None
return func
+
return decorator
@@ -184,39 +198,42 @@ class MethodMapper(dict):
self[method] = self.action.__name__
def _map(self, method, func):
- assert method not in self, (
- "Method '%s' has already been mapped to '.%s'." % (method, self[method]))
+ assert method not in self, "Method '%s' has already been mapped to '.%s'." % (
+ method,
+ self[method],
+ )
assert func.__name__ != self.action.__name__, (
"Method mapping does not behave like the property decorator. You "
- "cannot use the same method name for each mapping declaration.")
+ "cannot use the same method name for each mapping declaration."
+ )
self[method] = func.__name__
return func
def get(self, func):
- return self._map('get', func)
+ return self._map("get", func)
def post(self, func):
- return self._map('post', func)
+ return self._map("post", func)
def put(self, func):
- return self._map('put', func)
+ return self._map("put", func)
def patch(self, func):
- return self._map('patch', func)
+ return self._map("patch", func)
def delete(self, func):
- return self._map('delete', func)
+ return self._map("delete", func)
def head(self, func):
- return self._map('head', func)
+ return self._map("head", func)
def options(self, func):
- return self._map('options', func)
+ return self._map("options", func)
def trace(self, func):
- return self._map('trace', func)
+ return self._map("trace", func)
def detail_route(methods=None, **kwargs):
@@ -226,14 +243,16 @@ def detail_route(methods=None, **kwargs):
warnings.warn(
"`detail_route` is deprecated and will be removed in 3.10 in favor of "
"`action`, which accepts a `detail` bool. Use `@action(detail=True)` instead.",
- RemovedInDRF310Warning, stacklevel=2
+ RemovedInDRF310Warning,
+ stacklevel=2,
)
def decorator(func):
func = action(methods, detail=True, **kwargs)(func)
- if 'url_name' not in kwargs:
- func.url_name = func.url_path.replace('_', '-')
+ if "url_name" not in kwargs:
+ func.url_name = func.url_path.replace("_", "-")
return func
+
return decorator
@@ -244,12 +263,14 @@ def list_route(methods=None, **kwargs):
warnings.warn(
"`list_route` is deprecated and will be removed in 3.10 in favor of "
"`action`, which accepts a `detail` bool. Use `@action(detail=False)` instead.",
- RemovedInDRF310Warning, stacklevel=2
+ RemovedInDRF310Warning,
+ stacklevel=2,
)
def decorator(func):
func = action(methods, detail=False, **kwargs)(func)
- if 'url_name' not in kwargs:
- func.url_name = func.url_path.replace('_', '-')
+ if "url_name" not in kwargs:
+ func.url_name = func.url_path.replace("_", "-")
return func
+
return decorator
diff --git a/rest_framework/documentation.py b/rest_framework/documentation.py
index 3a78bb341..f86d91c93 100644
--- a/rest_framework/documentation.py
+++ b/rest_framework/documentation.py
@@ -1,18 +1,25 @@
from django.conf.urls import include, url
from rest_framework.renderers import (
- CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer
+ CoreJSONRenderer,
+ DocumentationRenderer,
+ SchemaJSRenderer,
)
from rest_framework.schemas import SchemaGenerator, get_schema_view
from rest_framework.settings import api_settings
def get_docs_view(
- title=None, description=None, schema_url=None, public=True,
- patterns=None, generator_class=SchemaGenerator,
- authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
- permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
- renderer_classes=None):
+ title=None,
+ description=None,
+ schema_url=None,
+ public=True,
+ patterns=None,
+ generator_class=SchemaGenerator,
+ authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
+ permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
+ renderer_classes=None,
+):
if renderer_classes is None:
renderer_classes = [DocumentationRenderer, CoreJSONRenderer]
@@ -31,10 +38,15 @@ def get_docs_view(
def get_schemajs_view(
- title=None, description=None, schema_url=None, public=True,
- patterns=None, generator_class=SchemaGenerator,
- authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
- permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
+ title=None,
+ description=None,
+ schema_url=None,
+ public=True,
+ patterns=None,
+ generator_class=SchemaGenerator,
+ authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
+ permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
+):
renderer_classes = [SchemaJSRenderer]
return get_schema_view(
@@ -51,11 +63,16 @@ def get_schemajs_view(
def include_docs_urls(
- title=None, description=None, schema_url=None, public=True,
- patterns=None, generator_class=SchemaGenerator,
- authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
- permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
- renderer_classes=None):
+ title=None,
+ description=None,
+ schema_url=None,
+ public=True,
+ patterns=None,
+ generator_class=SchemaGenerator,
+ authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
+ permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
+ renderer_classes=None,
+):
docs_view = get_docs_view(
title=title,
description=description,
@@ -78,7 +95,7 @@ def include_docs_urls(
permission_classes=permission_classes,
)
urls = [
- url(r'^$', docs_view, name='docs-index'),
- url(r'^schema.js$', schema_js_view, name='schema-js')
+ url(r"^$", docs_view, name="docs-index"),
+ url(r"^schema.js$", schema_js_view, name="schema-js"),
]
- return include((urls, 'api-docs'), namespace='api-docs')
+ return include((urls, "api-docs"), namespace="api-docs")
diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py
index f79b16129..1dc3feca3 100644
--- a/rest_framework/exceptions.py
+++ b/rest_framework/exceptions.py
@@ -11,8 +11,7 @@ import math
from django.http import JsonResponse
from django.utils import six
from django.utils.encoding import force_text
-from django.utils.translation import ugettext_lazy as _
-from django.utils.translation import ungettext
+from django.utils.translation import ugettext_lazy as _, ungettext
from rest_framework import status
from rest_framework.compat import unicode_to_repr
@@ -25,23 +24,20 @@ def _get_error_details(data, default_code=None):
lazy translation strings or strings into `ErrorDetail`.
"""
if isinstance(data, list):
- ret = [
- _get_error_details(item, default_code) for item in data
- ]
+ ret = [_get_error_details(item, default_code) for item in data]
if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer)
return ret
elif isinstance(data, dict):
ret = {
- key: _get_error_details(value, default_code)
- for key, value in data.items()
+ key: _get_error_details(value, default_code) for key, value in data.items()
}
if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer)
return ret
text = force_text(data)
- code = getattr(data, 'code', default_code)
+ code = getattr(data, "code", default_code)
return ErrorDetail(text, code)
@@ -58,16 +54,14 @@ def _get_full_details(detail):
return [_get_full_details(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_full_details(value) for key, value in detail.items()}
- return {
- 'message': detail,
- 'code': detail.code
- }
+ return {"message": detail, "code": detail.code}
class ErrorDetail(six.text_type):
"""
A string-like object that can additionally have a code.
"""
+
code = None
def __new__(cls, string, code=None):
@@ -86,10 +80,9 @@ class ErrorDetail(six.text_type):
return not self.__eq__(other)
def __repr__(self):
- return unicode_to_repr('ErrorDetail(string=%r, code=%r)' % (
- six.text_type(self),
- self.code,
- ))
+ return unicode_to_repr(
+ "ErrorDetail(string=%r, code=%r)" % (six.text_type(self), self.code)
+ )
def __hash__(self):
return hash(str(self))
@@ -100,9 +93,10 @@ class APIException(Exception):
Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
+
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
- default_detail = _('A server error occurred.')
- default_code = 'error'
+ default_detail = _("A server error occurred.")
+ default_code = "error"
def __init__(self, detail=None, code=None):
if detail is None:
@@ -139,10 +133,11 @@ class APIException(Exception):
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
+
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
- default_detail = _('Invalid input.')
- default_code = 'invalid'
+ default_detail = _("Invalid input.")
+ default_code = "invalid"
def __init__(self, detail=None, code=None):
if detail is None:
@@ -160,38 +155,38 @@ class ValidationError(APIException):
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
- default_detail = _('Malformed request.')
- default_code = 'parse_error'
+ default_detail = _("Malformed request.")
+ default_code = "parse_error"
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
- default_detail = _('Incorrect authentication credentials.')
- default_code = 'authentication_failed'
+ default_detail = _("Incorrect authentication credentials.")
+ default_code = "authentication_failed"
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
- default_detail = _('Authentication credentials were not provided.')
- default_code = 'not_authenticated'
+ default_detail = _("Authentication credentials were not provided.")
+ default_code = "not_authenticated"
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
- default_detail = _('You do not have permission to perform this action.')
- default_code = 'permission_denied'
+ default_detail = _("You do not have permission to perform this action.")
+ default_code = "permission_denied"
class NotFound(APIException):
status_code = status.HTTP_404_NOT_FOUND
- default_detail = _('Not found.')
- default_code = 'not_found'
+ default_detail = _("Not found.")
+ default_code = "not_found"
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = _('Method "{method}" not allowed.')
- default_code = 'method_not_allowed'
+ default_code = "method_not_allowed"
def __init__(self, method, detail=None, code=None):
if detail is None:
@@ -201,8 +196,8 @@ class MethodNotAllowed(APIException):
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
- default_detail = _('Could not satisfy the request Accept header.')
- default_code = 'not_acceptable'
+ default_detail = _("Could not satisfy the request Accept header.")
+ default_code = "not_acceptable"
def __init__(self, detail=None, code=None, available_renderers=None):
self.available_renderers = available_renderers
@@ -212,7 +207,7 @@ class NotAcceptable(APIException):
class UnsupportedMediaType(APIException):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
default_detail = _('Unsupported media type "{media_type}" in request.')
- default_code = 'unsupported_media_type'
+ default_code = "unsupported_media_type"
def __init__(self, media_type, detail=None, code=None):
if detail is None:
@@ -222,21 +217,28 @@ class UnsupportedMediaType(APIException):
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
- default_detail = _('Request was throttled.')
- extra_detail_singular = 'Expected available in {wait} second.'
- extra_detail_plural = 'Expected available in {wait} seconds.'
- default_code = 'throttled'
+ default_detail = _("Request was throttled.")
+ extra_detail_singular = "Expected available in {wait} second."
+ extra_detail_plural = "Expected available in {wait} seconds."
+ default_code = "throttled"
def __init__(self, wait=None, detail=None, code=None):
if detail is None:
detail = force_text(self.default_detail)
if wait is not None:
wait = math.ceil(wait)
- detail = ' '.join((
- detail,
- force_text(ungettext(self.extra_detail_singular.format(wait=wait),
- self.extra_detail_plural.format(wait=wait),
- wait))))
+ detail = " ".join(
+ (
+ detail,
+ force_text(
+ ungettext(
+ self.extra_detail_singular.format(wait=wait),
+ self.extra_detail_plural.format(wait=wait),
+ wait,
+ )
+ ),
+ )
+ )
self.wait = wait
super(Throttled, self).__init__(detail, code)
@@ -245,9 +247,7 @@ def server_error(request, *args, **kwargs):
"""
Generic 500 error handler.
"""
- data = {
- 'error': 'Server Error (500)'
- }
+ data = {"error": "Server Error (500)"}
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@@ -255,7 +255,5 @@ def bad_request(request, exception, *args, **kwargs):
"""
Generic 400 error handler.
"""
- data = {
- 'error': 'Bad Request (400)'
- }
+ data = {"error": "Bad Request (400)"}
return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST)
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index c8f65db0e..998d4781f 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -10,16 +10,26 @@ import uuid
from collections import OrderedDict
from django.conf import settings
-from django.core.exceptions import ObjectDoesNotExist
-from django.core.exceptions import ValidationError as DjangoValidationError
-from django.core.validators import (
- EmailValidator, RegexValidator, URLValidator, ip_address_validators
+from django.core.exceptions import (
+ ObjectDoesNotExist,
+ ValidationError as DjangoValidationError,
+)
+from django.core.validators import (
+ EmailValidator,
+ RegexValidator,
+ URLValidator,
+ ip_address_validators,
+)
+from django.forms import (
+ FilePathField as DjangoFilePathField,
+ ImageField as DjangoImageField,
)
-from django.forms import FilePathField as DjangoFilePathField
-from django.forms import ImageField as DjangoImageField
from django.utils import six, timezone
from django.utils.dateparse import (
- parse_date, parse_datetime, parse_duration, parse_time
+ parse_date,
+ parse_datetime,
+ parse_duration,
+ parse_time,
)
from django.utils.duration import duration_string
from django.utils.encoding import is_protected_type, smart_text
@@ -32,9 +42,14 @@ from pytz.exceptions import InvalidTimeError
from rest_framework import ISO_8601
from rest_framework.compat import (
- Mapping, MaxLengthValidator, MaxValueValidator, MinLengthValidator,
- MinValueValidator, ProhibitNullCharactersValidator, unicode_repr,
- unicode_to_repr
+ Mapping,
+ MaxLengthValidator,
+ MaxValueValidator,
+ MinLengthValidator,
+ MinValueValidator,
+ ProhibitNullCharactersValidator,
+ unicode_repr,
+ unicode_to_repr,
)
from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings
@@ -48,27 +63,35 @@ class empty:
It is required because `None` may be a valid input or output value.
"""
+
pass
if six.PY3:
+
def is_simple_callable(obj):
"""
True if the object is a callable that takes no arguments.
"""
- if not (inspect.isfunction(obj) or inspect.ismethod(obj) or isinstance(obj, functools.partial)):
+ if not (
+ inspect.isfunction(obj)
+ or inspect.ismethod(obj)
+ or isinstance(obj, functools.partial)
+ ):
return False
sig = inspect.signature(obj)
params = sig.parameters.values()
return all(
- param.kind == param.VAR_POSITIONAL or
- param.kind == param.VAR_KEYWORD or
- param.default != param.empty
+ param.kind == param.VAR_POSITIONAL
+ or param.kind == param.VAR_KEYWORD
+ or param.default != param.empty
for param in params
)
+
else:
+
def is_simple_callable(obj):
function = inspect.isfunction(obj)
method = inspect.ismethod(obj)
@@ -108,7 +131,11 @@ def get_attribute(instance, attrs):
# If we raised an Attribute or KeyError here it'd get treated
# as an omitted field in `Field.get_attribute()`. Instead we
# raise a ValueError to ensure the exception is not masked.
- raise ValueError('Exception raised in callable attribute "{0}"; original exception was: {1}'.format(attr, exc))
+ raise ValueError(
+ 'Exception raised in callable attribute "{0}"; original exception was: {1}'.format(
+ attr, exc
+ )
+ )
return instance
@@ -185,6 +212,7 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None):
"""
Helper function for options and option groups in templates.
"""
+
class StartOptionGroup(object):
start_option_group = True
end_option_group = False
@@ -225,7 +253,7 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None):
if cutoff and count >= cutoff and cutoff_text:
cutoff_text = cutoff_text.format(count=cutoff)
- yield Option(value='n/a', display_text=cutoff_text, disabled=True)
+ yield Option(value="n/a", display_text=cutoff_text, disabled=True)
def get_error_detail(exc_info):
@@ -233,21 +261,27 @@ def get_error_detail(exc_info):
Given a Django ValidationError, return a list of ErrorDetail,
with the `code` populated.
"""
- code = getattr(exc_info, 'code', None) or 'invalid'
+ code = getattr(exc_info, "code", None) or "invalid"
try:
error_dict = exc_info.error_dict
except AttributeError:
return [
- ErrorDetail(error.message % (error.params or ()),
- code=error.code if error.code else code)
- for error in exc_info.error_list]
+ ErrorDetail(
+ error.message % (error.params or ()),
+ code=error.code if error.code else code,
+ )
+ for error in exc_info.error_list
+ ]
return {
k: [
- ErrorDetail(error.message % (error.params or ()),
- code=error.code if error.code else code)
+ ErrorDetail(
+ error.message % (error.params or ()),
+ code=error.code if error.code else code,
+ )
for error in errors
- ] for k, errors in error_dict.items()
+ ]
+ for k, errors in error_dict.items()
}
@@ -257,12 +291,17 @@ class CreateOnlyDefault(object):
for create operations, but that do not return any value for update
operations.
"""
+
def __init__(self, default):
self.default = default
def set_context(self, serializer_field):
self.is_update = serializer_field.parent.instance is not None
- if callable(self.default) and hasattr(self.default, 'set_context') and not self.is_update:
+ if (
+ callable(self.default)
+ and hasattr(self.default, "set_context")
+ and not self.is_update
+ ):
self.default.set_context(serializer_field)
def __call__(self):
@@ -274,34 +313,34 @@ class CreateOnlyDefault(object):
def __repr__(self):
return unicode_to_repr(
- '%s(%s)' % (self.__class__.__name__, unicode_repr(self.default))
+ "%s(%s)" % (self.__class__.__name__, unicode_repr(self.default))
)
class CurrentUserDefault(object):
def set_context(self, serializer_field):
- self.user = serializer_field.context['request'].user
+ self.user = serializer_field.context["request"].user
def __call__(self):
return self.user
def __repr__(self):
- return unicode_to_repr('%s()' % self.__class__.__name__)
+ return unicode_to_repr("%s()" % self.__class__.__name__)
class SkipField(Exception):
pass
-REGEX_TYPE = type(re.compile(''))
+REGEX_TYPE = type(re.compile(""))
-NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
-NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
-NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
-USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField'
+NOT_READ_ONLY_WRITE_ONLY = "May not set both `read_only` and `write_only`"
+NOT_READ_ONLY_REQUIRED = "May not set both `read_only` and `required`"
+NOT_REQUIRED_DEFAULT = "May not set both `required` and `default`"
+USE_READONLYFIELD = "Field(read_only=True) should be ReadOnlyField"
MISSING_ERROR_MESSAGE = (
- 'ValidationError raised by `{class_name}`, but error key `{key}` does '
- 'not exist in the `error_messages` dictionary.'
+ "ValidationError raised by `{class_name}`, but error key `{key}` does "
+ "not exist in the `error_messages` dictionary."
)
@@ -309,17 +348,28 @@ class Field(object):
_creation_counter = 0
default_error_messages = {
- 'required': _('This field is required.'),
- 'null': _('This field may not be null.')
+ "required": _("This field is required."),
+ "null": _("This field may not be null."),
}
default_validators = []
default_empty_html = empty
initial = None
- def __init__(self, read_only=False, write_only=False,
- required=None, default=empty, initial=empty, source=None,
- label=None, help_text=None, style=None,
- error_messages=None, validators=None, allow_null=False):
+ def __init__(
+ self,
+ read_only=False,
+ write_only=False,
+ required=None,
+ default=empty,
+ initial=empty,
+ source=None,
+ label=None,
+ help_text=None,
+ style=None,
+ error_messages=None,
+ validators=None,
+ allow_null=False,
+ ):
self._creation_counter = Field._creation_counter
Field._creation_counter += 1
@@ -358,7 +408,7 @@ class Field(object):
# Collect default error message from self and parent classes
messages = {}
for cls in reversed(self.__class__.__mro__):
- messages.update(getattr(cls, 'default_error_messages', {}))
+ messages.update(getattr(cls, "default_error_messages", {}))
messages.update(error_messages or {})
self.error_messages = messages
@@ -374,8 +424,8 @@ class Field(object):
assert self.source != field_name, (
"It is redundant to specify `source='%s'` on field '%s' in "
"serializer '%s', because it is the same as the field name. "
- "Remove the `source` keyword argument." %
- (field_name, self.__class__.__name__, parent.__class__.__name__)
+ "Remove the `source` keyword argument."
+ % (field_name, self.__class__.__name__, parent.__class__.__name__)
)
self.field_name = field_name
@@ -383,7 +433,7 @@ class Field(object):
# `self.label` should default to being based on the field name.
if self.label is None:
- self.label = field_name.replace('_', ' ').capitalize()
+ self.label = field_name.replace("_", " ").capitalize()
# self.source should default to being the same as the field name.
if self.source is None:
@@ -391,16 +441,16 @@ class Field(object):
# self.source_attrs is a list of attributes that need to be looked up
# when serializing the instance, or populating the validated data.
- if self.source == '*':
+ if self.source == "*":
self.source_attrs = []
else:
- self.source_attrs = self.source.split('.')
+ self.source_attrs = self.source.split(".")
# .validators is a lazily loaded property, that gets its default
# value from `get_validators`.
@property
def validators(self):
- if not hasattr(self, '_validators'):
+ if not hasattr(self, "_validators"):
self._validators = self.get_validators()
return self._validators
@@ -429,18 +479,18 @@ class Field(object):
# HTML forms will represent empty fields as '', and cannot
# represent None or False values directly.
if self.field_name not in dictionary:
- if getattr(self.root, 'partial', False):
+ if getattr(self.root, "partial", False):
return empty
return self.default_empty_html
ret = dictionary[self.field_name]
- if ret == '' and self.allow_null:
+ if ret == "" and self.allow_null:
# If the field is blank, and null is a valid value then
# determine if we should use null instead.
- return '' if getattr(self, 'allow_blank', False) else None
- elif ret == '' and not self.required:
+ return "" if getattr(self, "allow_blank", False) else None
+ elif ret == "" and not self.required:
# If the field is blank, and emptiness is valid then
# determine if we should use emptiness instead.
- return '' if getattr(self, 'allow_blank', False) else empty
+ return "" if getattr(self, "allow_blank", False) else empty
return ret
return dictionary.get(self.field_name, empty)
@@ -459,16 +509,16 @@ class Field(object):
if not self.required:
raise SkipField()
msg = (
- 'Got {exc_type} when attempting to get a value for field '
- '`{field}` on serializer `{serializer}`.\nThe serializer '
- 'field might be named incorrectly and not match '
- 'any attribute or key on the `{instance}` instance.\n'
- 'Original exception text was: {exc}.'.format(
+ "Got {exc_type} when attempting to get a value for field "
+ "`{field}` on serializer `{serializer}`.\nThe serializer "
+ "field might be named incorrectly and not match "
+ "any attribute or key on the `{instance}` instance.\n"
+ "Original exception text was: {exc}.".format(
exc_type=type(exc).__name__,
field=self.field_name,
serializer=self.parent.__class__.__name__,
instance=instance.__class__.__name__,
- exc=exc
+ exc=exc,
)
)
raise type(exc)(msg)
@@ -482,11 +532,11 @@ class Field(object):
raise `SkipField`, indicating that no value should be set in the
validated data for this field.
"""
- if self.default is empty or getattr(self.root, 'partial', False):
+ if self.default is empty or getattr(self.root, "partial", False):
# No default, or this is a partial update.
raise SkipField()
if callable(self.default):
- if hasattr(self.default, 'set_context'):
+ if hasattr(self.default, "set_context"):
self.default.set_context(self)
return self.default()
return self.default
@@ -506,15 +556,15 @@ class Field(object):
return (True, self.get_default())
if data is empty:
- if getattr(self.root, 'partial', False):
+ if getattr(self.root, "partial", False):
raise SkipField()
if self.required:
- self.fail('required')
+ self.fail("required")
return (True, self.get_default())
if data is None:
if not self.allow_null:
- self.fail('null')
+ self.fail("null")
return (True, None)
return (False, data)
@@ -543,7 +593,7 @@ class Field(object):
"""
errors = []
for validator in self.validators:
- if hasattr(validator, 'set_context'):
+ if hasattr(validator, "set_context"):
validator.set_context(self)
try:
@@ -565,7 +615,7 @@ class Field(object):
Transform the *incoming* primitive data into a native value.
"""
raise NotImplementedError(
- '{cls}.to_internal_value() must be implemented.'.format(
+ "{cls}.to_internal_value() must be implemented.".format(
cls=self.__class__.__name__
)
)
@@ -575,11 +625,10 @@ class Field(object):
Transform the *outgoing* native value into primitive data.
"""
raise NotImplementedError(
- '{cls}.to_representation() must be implemented for field '
- '{field_name}. If you do not need to support write operations '
- 'you probably want to subclass `ReadOnlyField` instead.'.format(
- cls=self.__class__.__name__,
- field_name=self.field_name,
+ "{cls}.to_representation() must be implemented for field "
+ "{field_name}. If you do not need to support write operations "
+ "you probably want to subclass `ReadOnlyField` instead.".format(
+ cls=self.__class__.__name__, field_name=self.field_name
)
)
@@ -611,7 +660,7 @@ class Field(object):
"""
Returns the context as passed to the root serializer on initialization.
"""
- return getattr(self.root, '_context', {})
+ return getattr(self.root, "_context", {})
def __new__(cls, *args, **kwargs):
"""
@@ -636,7 +685,9 @@ class Field(object):
for item in self._args
]
kwargs = {
- key: (copy.deepcopy(value) if (key not in ('validators', 'regex')) else value)
+ key: (
+ copy.deepcopy(value) if (key not in ("validators", "regex")) else value
+ )
for key, value in self._kwargs.items()
}
return self.__class__(*args, **kwargs)
@@ -652,29 +703,47 @@ class Field(object):
# Boolean types...
+
class BooleanField(Field):
- default_error_messages = {
- 'invalid': _('Must be a valid boolean.')
- }
+ default_error_messages = {"invalid": _("Must be a valid boolean.")}
default_empty_html = False
initial = False
TRUE_VALUES = {
- 't', 'T',
- 'y', 'Y', 'yes', 'YES',
- 'true', 'True', 'TRUE',
- 'on', 'On', 'ON',
- '1', 1,
- True
+ "t",
+ "T",
+ "y",
+ "Y",
+ "yes",
+ "YES",
+ "true",
+ "True",
+ "TRUE",
+ "on",
+ "On",
+ "ON",
+ "1",
+ 1,
+ True,
}
FALSE_VALUES = {
- 'f', 'F',
- 'n', 'N', 'no', 'NO',
- 'false', 'False', 'FALSE',
- 'off', 'Off', 'OFF',
- '0', 0, 0.0,
- False
+ "f",
+ "F",
+ "n",
+ "N",
+ "no",
+ "NO",
+ "false",
+ "False",
+ "FALSE",
+ "off",
+ "Off",
+ "OFF",
+ "0",
+ 0,
+ 0.0,
+ False,
}
- NULL_VALUES = {'null', 'Null', 'NULL', '', None}
+ NULL_VALUES = {"null", "Null", "NULL", "", None}
def to_internal_value(self, data):
try:
@@ -686,7 +755,7 @@ class BooleanField(Field):
return None
except TypeError: # Input is an unhashable type
pass
- self.fail('invalid', input=data)
+ self.fail("invalid", input=data)
def to_representation(self, value):
if value in self.TRUE_VALUES:
@@ -699,31 +768,48 @@ class BooleanField(Field):
class NullBooleanField(Field):
- default_error_messages = {
- 'invalid': _('Must be a valid boolean.')
- }
+ default_error_messages = {"invalid": _("Must be a valid boolean.")}
initial = None
TRUE_VALUES = {
- 't', 'T',
- 'y', 'Y', 'yes', 'YES',
- 'true', 'True', 'TRUE',
- 'on', 'On', 'ON',
- '1', 1,
- True
+ "t",
+ "T",
+ "y",
+ "Y",
+ "yes",
+ "YES",
+ "true",
+ "True",
+ "TRUE",
+ "on",
+ "On",
+ "ON",
+ "1",
+ 1,
+ True,
}
FALSE_VALUES = {
- 'f', 'F',
- 'n', 'N', 'no', 'NO',
- 'false', 'False', 'FALSE',
- 'off', 'Off', 'OFF',
- '0', 0, 0.0,
- False
+ "f",
+ "F",
+ "n",
+ "N",
+ "no",
+ "NO",
+ "false",
+ "False",
+ "FALSE",
+ "off",
+ "Off",
+ "OFF",
+ "0",
+ 0,
+ 0.0,
+ False,
}
- NULL_VALUES = {'null', 'Null', 'NULL', '', None}
+ NULL_VALUES = {"null", "Null", "NULL", "", None}
def __init__(self, **kwargs):
- assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.'
- kwargs['allow_null'] = True
+ assert "allow_null" not in kwargs, "`allow_null` is not a valid option."
+ kwargs["allow_null"] = True
super(NullBooleanField, self).__init__(**kwargs)
def to_internal_value(self, data):
@@ -736,7 +822,7 @@ class NullBooleanField(Field):
return None
except TypeError: # Input is an unhashable type
pass
- self.fail('invalid', input=data)
+ self.fail("invalid", input=data)
def to_representation(self, value):
if value in self.NULL_VALUES:
@@ -750,33 +836,32 @@ class NullBooleanField(Field):
# String types...
+
class CharField(Field):
default_error_messages = {
- 'invalid': _('Not a valid string.'),
- 'blank': _('This field may not be blank.'),
- 'max_length': _('Ensure this field has no more than {max_length} characters.'),
- 'min_length': _('Ensure this field has at least {min_length} characters.'),
+ "invalid": _("Not a valid string."),
+ "blank": _("This field may not be blank."),
+ "max_length": _("Ensure this field has no more than {max_length} characters."),
+ "min_length": _("Ensure this field has at least {min_length} characters."),
}
- initial = ''
+ initial = ""
def __init__(self, **kwargs):
- self.allow_blank = kwargs.pop('allow_blank', False)
- self.trim_whitespace = kwargs.pop('trim_whitespace', True)
- self.max_length = kwargs.pop('max_length', None)
- self.min_length = kwargs.pop('min_length', None)
+ self.allow_blank = kwargs.pop("allow_blank", False)
+ self.trim_whitespace = kwargs.pop("trim_whitespace", True)
+ self.max_length = kwargs.pop("max_length", None)
+ self.min_length = kwargs.pop("min_length", None)
super(CharField, self).__init__(**kwargs)
if self.max_length is not None:
- message = lazy(
- self.error_messages['max_length'].format,
- six.text_type)(max_length=self.max_length)
- self.validators.append(
- MaxLengthValidator(self.max_length, message=message))
+ message = lazy(self.error_messages["max_length"].format, six.text_type)(
+ max_length=self.max_length
+ )
+ self.validators.append(MaxLengthValidator(self.max_length, message=message))
if self.min_length is not None:
- message = lazy(
- self.error_messages['min_length'].format,
- six.text_type)(min_length=self.min_length)
- self.validators.append(
- MinLengthValidator(self.min_length, message=message))
+ message = lazy(self.error_messages["min_length"].format, six.text_type)(
+ min_length=self.min_length
+ )
+ self.validators.append(MinLengthValidator(self.min_length, message=message))
# ProhibitNullCharactersValidator is None on Django < 2.0
if ProhibitNullCharactersValidator is not None:
@@ -786,18 +871,20 @@ class CharField(Field):
# Test for the empty string here so that it does not get validated,
# and so that subclasses do not need to handle it explicitly
# inside the `to_internal_value()` method.
- if data == '' or (self.trim_whitespace and six.text_type(data).strip() == ''):
+ if data == "" or (self.trim_whitespace and six.text_type(data).strip() == ""):
if not self.allow_blank:
- self.fail('blank')
- return ''
+ self.fail("blank")
+ return ""
return super(CharField, self).run_validation(data)
def to_internal_value(self, data):
# We're lenient with allowing basic numerics to be coerced into strings,
# but other types should fail. Eg. unclear if booleans should represent as `true` or `True`,
# and composites such as lists are likely user error.
- if isinstance(data, bool) or not isinstance(data, six.string_types + six.integer_types + (float,)):
- self.fail('invalid')
+ if isinstance(data, bool) or not isinstance(
+ data, six.string_types + six.integer_types + (float,)
+ ):
+ self.fail("invalid")
value = six.text_type(data)
return value.strip() if self.trim_whitespace else value
@@ -806,66 +893,69 @@ class CharField(Field):
class EmailField(CharField):
- default_error_messages = {
- 'invalid': _('Enter a valid email address.')
- }
+ default_error_messages = {"invalid": _("Enter a valid email address.")}
def __init__(self, **kwargs):
super(EmailField, self).__init__(**kwargs)
- validator = EmailValidator(message=self.error_messages['invalid'])
+ validator = EmailValidator(message=self.error_messages["invalid"])
self.validators.append(validator)
class RegexField(CharField):
default_error_messages = {
- 'invalid': _('This value does not match the required pattern.')
+ "invalid": _("This value does not match the required pattern.")
}
def __init__(self, regex, **kwargs):
super(RegexField, self).__init__(**kwargs)
- validator = RegexValidator(regex, message=self.error_messages['invalid'])
+ validator = RegexValidator(regex, message=self.error_messages["invalid"])
self.validators.append(validator)
class SlugField(CharField):
default_error_messages = {
- 'invalid': _('Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'),
- 'invalid_unicode': _('Enter a valid "slug" consisting of Unicode letters, numbers, underscores, or hyphens.')
+ "invalid": _(
+ 'Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'
+ ),
+ "invalid_unicode": _(
+ 'Enter a valid "slug" consisting of Unicode letters, numbers, underscores, or hyphens.'
+ ),
}
def __init__(self, allow_unicode=False, **kwargs):
super(SlugField, self).__init__(**kwargs)
self.allow_unicode = allow_unicode
if self.allow_unicode:
- validator = RegexValidator(re.compile(r'^[-\w]+\Z', re.UNICODE), message=self.error_messages['invalid_unicode'])
+ validator = RegexValidator(
+ re.compile(r"^[-\w]+\Z", re.UNICODE),
+ message=self.error_messages["invalid_unicode"],
+ )
else:
- validator = RegexValidator(re.compile(r'^[-a-zA-Z0-9_]+$'), message=self.error_messages['invalid'])
+ validator = RegexValidator(
+ re.compile(r"^[-a-zA-Z0-9_]+$"), message=self.error_messages["invalid"]
+ )
self.validators.append(validator)
class URLField(CharField):
- default_error_messages = {
- 'invalid': _('Enter a valid URL.')
- }
+ default_error_messages = {"invalid": _("Enter a valid URL.")}
def __init__(self, **kwargs):
super(URLField, self).__init__(**kwargs)
- validator = URLValidator(message=self.error_messages['invalid'])
+ validator = URLValidator(message=self.error_messages["invalid"])
self.validators.append(validator)
class UUIDField(Field):
- valid_formats = ('hex_verbose', 'hex', 'int', 'urn')
+ valid_formats = ("hex_verbose", "hex", "int", "urn")
- default_error_messages = {
- 'invalid': _('Must be a valid UUID.'),
- }
+ default_error_messages = {"invalid": _("Must be a valid UUID.")}
def __init__(self, **kwargs):
- self.uuid_format = kwargs.pop('format', 'hex_verbose')
+ self.uuid_format = kwargs.pop("format", "hex_verbose")
if self.uuid_format not in self.valid_formats:
raise ValueError(
- 'Invalid format for uuid representation. '
+ "Invalid format for uuid representation. "
'Must be one of "{0}"'.format('", "'.join(self.valid_formats))
)
super(UUIDField, self).__init__(**kwargs)
@@ -878,13 +968,13 @@ class UUIDField(Field):
elif isinstance(data, six.string_types):
return uuid.UUID(hex=data)
else:
- self.fail('invalid', value=data)
+ self.fail("invalid", value=data)
except (ValueError):
- self.fail('invalid', value=data)
+ self.fail("invalid", value=data)
return data
def to_representation(self, value):
- if self.uuid_format == 'hex_verbose':
+ if self.uuid_format == "hex_verbose":
return str(value)
else:
return getattr(value, self.uuid_format)
@@ -893,68 +983,65 @@ class UUIDField(Field):
class IPAddressField(CharField):
"""Support both IPAddressField and GenericIPAddressField"""
- default_error_messages = {
- 'invalid': _('Enter a valid IPv4 or IPv6 address.'),
- }
+ default_error_messages = {"invalid": _("Enter a valid IPv4 or IPv6 address.")}
- def __init__(self, protocol='both', **kwargs):
+ def __init__(self, protocol="both", **kwargs):
self.protocol = protocol.lower()
- self.unpack_ipv4 = (self.protocol == 'both')
+ self.unpack_ipv4 = self.protocol == "both"
super(IPAddressField, self).__init__(**kwargs)
validators, error_message = ip_address_validators(protocol, self.unpack_ipv4)
self.validators.extend(validators)
def to_internal_value(self, data):
if not isinstance(data, six.string_types):
- self.fail('invalid', value=data)
+ self.fail("invalid", value=data)
- if ':' in data:
+ if ":" in data:
try:
- if self.protocol in ('both', 'ipv6'):
+ if self.protocol in ("both", "ipv6"):
return clean_ipv6_address(data, self.unpack_ipv4)
except DjangoValidationError:
- self.fail('invalid', value=data)
+ self.fail("invalid", value=data)
return super(IPAddressField, self).to_internal_value(data)
# Number types...
+
class IntegerField(Field):
default_error_messages = {
- 'invalid': _('A valid integer is required.'),
- 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
- 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
- 'max_string_length': _('String value too large.')
+ "invalid": _("A valid integer is required."),
+ "max_value": _("Ensure this value is less than or equal to {max_value}."),
+ "min_value": _("Ensure this value is greater than or equal to {min_value}."),
+ "max_string_length": _("String value too large."),
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
- re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2'
+ re_decimal = re.compile(r"\.0*\s*$") # allow e.g. '1.0' as an int, but not '1.2'
def __init__(self, **kwargs):
- self.max_value = kwargs.pop('max_value', None)
- self.min_value = kwargs.pop('min_value', None)
+ self.max_value = kwargs.pop("max_value", None)
+ self.min_value = kwargs.pop("min_value", None)
super(IntegerField, self).__init__(**kwargs)
if self.max_value is not None:
- message = lazy(
- self.error_messages['max_value'].format,
- six.text_type)(max_value=self.max_value)
- self.validators.append(
- MaxValueValidator(self.max_value, message=message))
+ message = lazy(self.error_messages["max_value"].format, six.text_type)(
+ max_value=self.max_value
+ )
+ self.validators.append(MaxValueValidator(self.max_value, message=message))
if self.min_value is not None:
- message = lazy(
- self.error_messages['min_value'].format,
- six.text_type)(min_value=self.min_value)
- self.validators.append(
- MinValueValidator(self.min_value, message=message))
+ message = lazy(self.error_messages["min_value"].format, six.text_type)(
+ min_value=self.min_value
+ )
+ self.validators.append(MinValueValidator(self.min_value, message=message))
def to_internal_value(self, data):
if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH:
- self.fail('max_string_length')
+ self.fail("max_string_length")
try:
- data = int(self.re_decimal.sub('', str(data)))
+ data = int(self.re_decimal.sub("", str(data)))
except (ValueError, TypeError):
- self.fail('invalid')
+ self.fail("invalid")
return data
def to_representation(self, value):
@@ -963,39 +1050,37 @@ class IntegerField(Field):
class FloatField(Field):
default_error_messages = {
- 'invalid': _('A valid number is required.'),
- 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
- 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
- 'max_string_length': _('String value too large.')
+ "invalid": _("A valid number is required."),
+ "max_value": _("Ensure this value is less than or equal to {max_value}."),
+ "min_value": _("Ensure this value is greater than or equal to {min_value}."),
+ "max_string_length": _("String value too large."),
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
def __init__(self, **kwargs):
- self.max_value = kwargs.pop('max_value', None)
- self.min_value = kwargs.pop('min_value', None)
+ self.max_value = kwargs.pop("max_value", None)
+ self.min_value = kwargs.pop("min_value", None)
super(FloatField, self).__init__(**kwargs)
if self.max_value is not None:
- message = lazy(
- self.error_messages['max_value'].format,
- six.text_type)(max_value=self.max_value)
- self.validators.append(
- MaxValueValidator(self.max_value, message=message))
+ message = lazy(self.error_messages["max_value"].format, six.text_type)(
+ max_value=self.max_value
+ )
+ self.validators.append(MaxValueValidator(self.max_value, message=message))
if self.min_value is not None:
- message = lazy(
- self.error_messages['min_value'].format,
- six.text_type)(min_value=self.min_value)
- self.validators.append(
- MinValueValidator(self.min_value, message=message))
+ message = lazy(self.error_messages["min_value"].format, six.text_type)(
+ min_value=self.min_value
+ )
+ self.validators.append(MinValueValidator(self.min_value, message=message))
def to_internal_value(self, data):
if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH:
- self.fail('max_string_length')
+ self.fail("max_string_length")
try:
return float(data)
except (TypeError, ValueError):
- self.fail('invalid')
+ self.fail("invalid")
def to_representation(self, value):
return float(value)
@@ -1003,18 +1088,33 @@ class FloatField(Field):
class DecimalField(Field):
default_error_messages = {
- 'invalid': _('A valid number is required.'),
- 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
- 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
- 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
- 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
- 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'),
- 'max_string_length': _('String value too large.')
+ "invalid": _("A valid number is required."),
+ "max_value": _("Ensure this value is less than or equal to {max_value}."),
+ "min_value": _("Ensure this value is greater than or equal to {min_value}."),
+ "max_digits": _(
+ "Ensure that there are no more than {max_digits} digits in total."
+ ),
+ "max_decimal_places": _(
+ "Ensure that there are no more than {max_decimal_places} decimal places."
+ ),
+ "max_whole_digits": _(
+ "Ensure that there are no more than {max_whole_digits} digits before the decimal point."
+ ),
+ "max_string_length": _("String value too large."),
}
MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs.
- def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None,
- localize=False, rounding=None, **kwargs):
+ def __init__(
+ self,
+ max_digits,
+ decimal_places,
+ coerce_to_string=None,
+ max_value=None,
+ min_value=None,
+ localize=False,
+ rounding=None,
+ **kwargs
+ ):
self.max_digits = max_digits
self.decimal_places = decimal_places
self.localize = localize
@@ -1034,22 +1134,24 @@ class DecimalField(Field):
super(DecimalField, self).__init__(**kwargs)
if self.max_value is not None:
- message = lazy(
- self.error_messages['max_value'].format,
- six.text_type)(max_value=self.max_value)
- self.validators.append(
- MaxValueValidator(self.max_value, message=message))
+ message = lazy(self.error_messages["max_value"].format, six.text_type)(
+ max_value=self.max_value
+ )
+ self.validators.append(MaxValueValidator(self.max_value, message=message))
if self.min_value is not None:
- message = lazy(
- self.error_messages['min_value'].format,
- six.text_type)(min_value=self.min_value)
- self.validators.append(
- MinValueValidator(self.min_value, message=message))
+ message = lazy(self.error_messages["min_value"].format, six.text_type)(
+ min_value=self.min_value
+ )
+ self.validators.append(MinValueValidator(self.min_value, message=message))
if rounding is not None:
- valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')]
+ valid_roundings = [
+ v for k, v in vars(decimal).items() if k.startswith("ROUND_")
+ ]
assert rounding in valid_roundings, (
- 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings))
+ "Invalid rounding option %s. Valid values for rounding are: %s"
+ % (rounding, valid_roundings)
+ )
self.rounding = rounding
def to_internal_value(self, data):
@@ -1064,21 +1166,21 @@ class DecimalField(Field):
data = sanitize_separators(data)
if len(data) > self.MAX_STRING_LENGTH:
- self.fail('max_string_length')
+ self.fail("max_string_length")
try:
value = decimal.Decimal(data)
except decimal.DecimalException:
- self.fail('invalid')
+ self.fail("invalid")
# Check for NaN. It is the only value that isn't equal to itself,
# so we can use this to identify NaN values.
if value != value:
- self.fail('invalid')
+ self.fail("invalid")
# Check for infinity and negative infinity.
- if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
- self.fail('invalid')
+ if value in (decimal.Decimal("Inf"), decimal.Decimal("-Inf")):
+ self.fail("invalid")
return self.quantize(self.validate_precision(value))
@@ -1109,16 +1211,18 @@ class DecimalField(Field):
decimal_places = total_digits
if self.max_digits is not None and total_digits > self.max_digits:
- self.fail('max_digits', max_digits=self.max_digits)
+ self.fail("max_digits", max_digits=self.max_digits)
if self.decimal_places is not None and decimal_places > self.decimal_places:
- self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
+ self.fail("max_decimal_places", max_decimal_places=self.decimal_places)
if self.max_whole_digits is not None and whole_digits > self.max_whole_digits:
- self.fail('max_whole_digits', max_whole_digits=self.max_whole_digits)
+ self.fail("max_whole_digits", max_whole_digits=self.max_whole_digits)
return value
def to_representation(self, value):
- coerce_to_string = getattr(self, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING)
+ coerce_to_string = getattr(
+ self, "coerce_to_string", api_settings.COERCE_DECIMAL_TO_STRING
+ )
if not isinstance(value, decimal.Decimal):
value = decimal.Decimal(six.text_type(value).strip())
@@ -1130,7 +1234,7 @@ class DecimalField(Field):
if self.localize:
return localize_input(quantized)
- return '{0:f}'.format(quantized)
+ return "{0:f}".format(quantized)
def quantize(self, value):
"""
@@ -1143,24 +1247,29 @@ class DecimalField(Field):
if self.max_digits is not None:
context.prec = self.max_digits
return value.quantize(
- decimal.Decimal('.1') ** self.decimal_places,
+ decimal.Decimal(".1") ** self.decimal_places,
rounding=self.rounding,
- context=context
+ context=context,
)
# Date & time fields...
+
class DateTimeField(Field):
default_error_messages = {
- 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'),
- 'date': _('Expected a datetime but got a date.'),
- 'make_aware': _('Invalid datetime for the timezone "{timezone}".'),
- 'overflow': _('Datetime value out of range.')
+ "invalid": _(
+ "Datetime has wrong format. Use one of these formats instead: {format}."
+ ),
+ "date": _("Expected a datetime but got a date."),
+ "make_aware": _('Invalid datetime for the timezone "{timezone}".'),
+ "overflow": _("Datetime value out of range."),
}
datetime_parser = datetime.datetime.strptime
- def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs):
+ def __init__(
+ self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs
+ ):
if format is not empty:
self.format = format
if input_formats is not None:
@@ -1174,18 +1283,18 @@ class DateTimeField(Field):
When `self.default_timezone` is `None`, always return naive datetimes.
When `self.default_timezone` is not `None`, always return aware datetimes.
"""
- field_timezone = getattr(self, 'timezone', self.default_timezone())
+ field_timezone = getattr(self, "timezone", self.default_timezone())
if field_timezone is not None:
if timezone.is_aware(value):
try:
return value.astimezone(field_timezone)
except OverflowError:
- self.fail('overflow')
+ self.fail("overflow")
try:
return timezone.make_aware(value, field_timezone)
except InvalidTimeError:
- self.fail('make_aware', timezone=field_timezone)
+ self.fail("make_aware", timezone=field_timezone)
elif (field_timezone is None) and timezone.is_aware(value):
return timezone.make_naive(value, utc)
return value
@@ -1194,10 +1303,14 @@ class DateTimeField(Field):
return timezone.get_current_timezone() if settings.USE_TZ else None
def to_internal_value(self, value):
- input_formats = getattr(self, 'input_formats', api_settings.DATETIME_INPUT_FORMATS)
+ input_formats = getattr(
+ self, "input_formats", api_settings.DATETIME_INPUT_FORMATS
+ )
- if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
- self.fail('date')
+ if isinstance(value, datetime.date) and not isinstance(
+ value, datetime.datetime
+ ):
+ self.fail("date")
if isinstance(value, datetime.datetime):
return self.enforce_timezone(value)
@@ -1218,13 +1331,13 @@ class DateTimeField(Field):
pass
humanized_format = humanize_datetime.datetime_formats(input_formats)
- self.fail('invalid', format=humanized_format)
+ self.fail("invalid", format=humanized_format)
def to_representation(self, value):
if not value:
return None
- output_format = getattr(self, 'format', api_settings.DATETIME_FORMAT)
+ output_format = getattr(self, "format", api_settings.DATETIME_FORMAT)
if output_format is None or isinstance(value, six.string_types):
return value
@@ -1233,16 +1346,18 @@ class DateTimeField(Field):
if output_format.lower() == ISO_8601:
value = value.isoformat()
- if value.endswith('+00:00'):
- value = value[:-6] + 'Z'
+ if value.endswith("+00:00"):
+ value = value[:-6] + "Z"
return value
return value.strftime(output_format)
class DateField(Field):
default_error_messages = {
- 'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'),
- 'datetime': _('Expected a date but got a datetime.'),
+ "invalid": _(
+ "Date has wrong format. Use one of these formats instead: {format}."
+ ),
+ "datetime": _("Expected a date but got a datetime."),
}
datetime_parser = datetime.datetime.strptime
@@ -1254,10 +1369,10 @@ class DateField(Field):
super(DateField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- input_formats = getattr(self, 'input_formats', api_settings.DATE_INPUT_FORMATS)
+ input_formats = getattr(self, "input_formats", api_settings.DATE_INPUT_FORMATS)
if isinstance(value, datetime.datetime):
- self.fail('datetime')
+ self.fail("datetime")
if isinstance(value, datetime.date):
return value
@@ -1280,13 +1395,13 @@ class DateField(Field):
return parsed.date()
humanized_format = humanize_datetime.date_formats(input_formats)
- self.fail('invalid', format=humanized_format)
+ self.fail("invalid", format=humanized_format)
def to_representation(self, value):
if not value:
return None
- output_format = getattr(self, 'format', api_settings.DATE_FORMAT)
+ output_format = getattr(self, "format", api_settings.DATE_FORMAT)
if output_format is None or isinstance(value, six.string_types):
return value
@@ -1295,9 +1410,9 @@ class DateField(Field):
# not a sensible thing to do, as it means naively dropping
# any explicit or implicit timezone info.
assert not isinstance(value, datetime.datetime), (
- 'Expected a `date`, but got a `datetime`. Refusing to coerce, '
- 'as this may mean losing timezone information. Use a custom '
- 'read-only field and deal with timezone issues explicitly.'
+ "Expected a `date`, but got a `datetime`. Refusing to coerce, "
+ "as this may mean losing timezone information. Use a custom "
+ "read-only field and deal with timezone issues explicitly."
)
if output_format.lower() == ISO_8601:
@@ -1308,7 +1423,9 @@ class DateField(Field):
class TimeField(Field):
default_error_messages = {
- 'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'),
+ "invalid": _(
+ "Time has wrong format. Use one of these formats instead: {format}."
+ )
}
datetime_parser = datetime.datetime.strptime
@@ -1320,7 +1437,7 @@ class TimeField(Field):
super(TimeField, self).__init__(*args, **kwargs)
def to_internal_value(self, value):
- input_formats = getattr(self, 'input_formats', api_settings.TIME_INPUT_FORMATS)
+ input_formats = getattr(self, "input_formats", api_settings.TIME_INPUT_FORMATS)
if isinstance(value, datetime.time):
return value
@@ -1343,13 +1460,13 @@ class TimeField(Field):
return parsed.time()
humanized_format = humanize_datetime.time_formats(input_formats)
- self.fail('invalid', format=humanized_format)
+ self.fail("invalid", format=humanized_format)
def to_representation(self, value):
- if value in (None, ''):
+ if value in (None, ""):
return None
- output_format = getattr(self, 'format', api_settings.TIME_FORMAT)
+ output_format = getattr(self, "format", api_settings.TIME_FORMAT)
if output_format is None or isinstance(value, six.string_types):
return value
@@ -1358,9 +1475,9 @@ class TimeField(Field):
# not a sensible thing to do, as it means naively dropping
# any explicit or implicit timezone info.
assert not isinstance(value, datetime.datetime), (
- 'Expected a `time`, but got a `datetime`. Refusing to coerce, '
- 'as this may mean losing timezone information. Use a custom '
- 'read-only field and deal with timezone issues explicitly.'
+ "Expected a `time`, but got a `datetime`. Refusing to coerce, "
+ "as this may mean losing timezone information. Use a custom "
+ "read-only field and deal with timezone issues explicitly."
)
if output_format.lower() == ISO_8601:
@@ -1370,27 +1487,27 @@ class TimeField(Field):
class DurationField(Field):
default_error_messages = {
- 'invalid': _('Duration has wrong format. Use one of these formats instead: {format}.'),
- 'max_value': _('Ensure this value is less than or equal to {max_value}.'),
- 'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
+ "invalid": _(
+ "Duration has wrong format. Use one of these formats instead: {format}."
+ ),
+ "max_value": _("Ensure this value is less than or equal to {max_value}."),
+ "min_value": _("Ensure this value is greater than or equal to {min_value}."),
}
def __init__(self, **kwargs):
- self.max_value = kwargs.pop('max_value', None)
- self.min_value = kwargs.pop('min_value', None)
+ self.max_value = kwargs.pop("max_value", None)
+ self.min_value = kwargs.pop("min_value", None)
super(DurationField, self).__init__(**kwargs)
if self.max_value is not None:
- message = lazy(
- self.error_messages['max_value'].format,
- six.text_type)(max_value=self.max_value)
- self.validators.append(
- MaxValueValidator(self.max_value, message=message))
+ message = lazy(self.error_messages["max_value"].format, six.text_type)(
+ max_value=self.max_value
+ )
+ self.validators.append(MaxValueValidator(self.max_value, message=message))
if self.min_value is not None:
- message = lazy(
- self.error_messages['min_value'].format,
- six.text_type)(min_value=self.min_value)
- self.validators.append(
- MinValueValidator(self.min_value, message=message))
+ message = lazy(self.error_messages["min_value"].format, six.text_type)(
+ min_value=self.min_value
+ )
+ self.validators.append(MinValueValidator(self.min_value, message=message))
def to_internal_value(self, value):
if isinstance(value, datetime.timedelta):
@@ -1398,7 +1515,7 @@ class DurationField(Field):
parsed = parse_duration(six.text_type(value))
if parsed is not None:
return parsed
- self.fail('invalid', format='[DD] [HH:[MM:]]ss[.uuuuuu]')
+ self.fail("invalid", format="[DD] [HH:[MM:]]ss[.uuuuuu]")
def to_representation(self, value):
return duration_string(value)
@@ -1406,33 +1523,32 @@ class DurationField(Field):
# Choice types...
+
class ChoiceField(Field):
- default_error_messages = {
- 'invalid_choice': _('"{input}" is not a valid choice.')
- }
+ default_error_messages = {"invalid_choice": _('"{input}" is not a valid choice.')}
html_cutoff = None
- html_cutoff_text = _('More than {count} items...')
+ html_cutoff_text = _("More than {count} items...")
def __init__(self, choices, **kwargs):
self.choices = choices
- self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff)
- self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text)
+ self.html_cutoff = kwargs.pop("html_cutoff", self.html_cutoff)
+ self.html_cutoff_text = kwargs.pop("html_cutoff_text", self.html_cutoff_text)
- self.allow_blank = kwargs.pop('allow_blank', False)
+ self.allow_blank = kwargs.pop("allow_blank", False)
super(ChoiceField, self).__init__(**kwargs)
def to_internal_value(self, data):
- if data == '' and self.allow_blank:
- return ''
+ if data == "" and self.allow_blank:
+ return ""
try:
return self.choice_strings_to_values[six.text_type(data)]
except KeyError:
- self.fail('invalid_choice', input=data)
+ self.fail("invalid_choice", input=data)
def to_representation(self, value):
- if value in ('', None):
+ if value in ("", None):
return value
return self.choice_strings_to_values.get(six.text_type(value), value)
@@ -1443,7 +1559,7 @@ class ChoiceField(Field):
return iter_options(
self.grouped_choices,
cutoff=self.html_cutoff,
- cutoff_text=self.html_cutoff_text
+ cutoff_text=self.html_cutoff_text,
)
def _get_choices(self):
@@ -1465,19 +1581,19 @@ class ChoiceField(Field):
class MultipleChoiceField(ChoiceField):
default_error_messages = {
- 'invalid_choice': _('"{input}" is not a valid choice.'),
- 'not_a_list': _('Expected a list of items but got type "{input_type}".'),
- 'empty': _('This selection may not be empty.')
+ "invalid_choice": _('"{input}" is not a valid choice.'),
+ "not_a_list": _('Expected a list of items but got type "{input_type}".'),
+ "empty": _("This selection may not be empty."),
}
default_empty_html = []
def __init__(self, *args, **kwargs):
- self.allow_empty = kwargs.pop('allow_empty', True)
+ self.allow_empty = kwargs.pop("allow_empty", True)
super(MultipleChoiceField, self).__init__(*args, **kwargs)
def get_value(self, dictionary):
if self.field_name not in dictionary:
- if getattr(self.root, 'partial', False):
+ if getattr(self.root, "partial", False):
return empty
# We override the default field access in order to support
# lists in HTML forms.
@@ -1486,55 +1602,72 @@ class MultipleChoiceField(ChoiceField):
return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
- if isinstance(data, six.text_type) or not hasattr(data, '__iter__'):
- self.fail('not_a_list', input_type=type(data).__name__)
+ if isinstance(data, six.text_type) or not hasattr(data, "__iter__"):
+ self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
- self.fail('empty')
+ self.fail("empty")
return {
- super(MultipleChoiceField, self).to_internal_value(item)
- for item in data
+ super(MultipleChoiceField, self).to_internal_value(item) for item in data
}
def to_representation(self, value):
return {
- self.choice_strings_to_values.get(six.text_type(item), item) for item in value
+ self.choice_strings_to_values.get(six.text_type(item), item)
+ for item in value
}
class FilePathField(ChoiceField):
default_error_messages = {
- 'invalid_choice': _('"{input}" is not a valid path choice.')
+ "invalid_choice": _('"{input}" is not a valid path choice.')
}
- def __init__(self, path, match=None, recursive=False, allow_files=True,
- allow_folders=False, required=None, **kwargs):
+ def __init__(
+ self,
+ path,
+ match=None,
+ recursive=False,
+ allow_files=True,
+ allow_folders=False,
+ required=None,
+ **kwargs
+ ):
# Defer to Django's FilePathField implementation to get the
# valid set of choices.
field = DjangoFilePathField(
- path, match=match, recursive=recursive, allow_files=allow_files,
- allow_folders=allow_folders, required=required
+ path,
+ match=match,
+ recursive=recursive,
+ allow_files=allow_files,
+ allow_folders=allow_folders,
+ required=required,
)
- kwargs['choices'] = field.choices
+ kwargs["choices"] = field.choices
super(FilePathField, self).__init__(**kwargs)
# File types...
+
class FileField(Field):
default_error_messages = {
- 'required': _('No file was submitted.'),
- 'invalid': _('The submitted data was not a file. Check the encoding type on the form.'),
- 'no_name': _('No filename could be determined.'),
- 'empty': _('The submitted file is empty.'),
- 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'),
+ "required": _("No file was submitted."),
+ "invalid": _(
+ "The submitted data was not a file. Check the encoding type on the form."
+ ),
+ "no_name": _("No filename could be determined."),
+ "empty": _("The submitted file is empty."),
+ "max_length": _(
+ "Ensure this filename has at most {max_length} characters (it has {length})."
+ ),
}
def __init__(self, *args, **kwargs):
- self.max_length = kwargs.pop('max_length', None)
- self.allow_empty_file = kwargs.pop('allow_empty_file', False)
- if 'use_url' in kwargs:
- self.use_url = kwargs.pop('use_url')
+ self.max_length = kwargs.pop("max_length", None)
+ self.allow_empty_file = kwargs.pop("allow_empty_file", False)
+ if "use_url" in kwargs:
+ self.use_url = kwargs.pop("use_url")
super(FileField, self).__init__(*args, **kwargs)
def to_internal_value(self, data):
@@ -1543,14 +1676,14 @@ class FileField(Field):
file_name = data.name
file_size = data.size
except AttributeError:
- self.fail('invalid')
+ self.fail("invalid")
if not file_name:
- self.fail('no_name')
+ self.fail("no_name")
if not self.allow_empty_file and not file_size:
- self.fail('empty')
+ self.fail("empty")
if self.max_length and len(file_name) > self.max_length:
- self.fail('max_length', max_length=self.max_length, length=len(file_name))
+ self.fail("max_length", max_length=self.max_length, length=len(file_name))
return data
@@ -1558,14 +1691,14 @@ class FileField(Field):
if not value:
return None
- use_url = getattr(self, 'use_url', api_settings.UPLOADED_FILES_USE_URL)
+ use_url = getattr(self, "use_url", api_settings.UPLOADED_FILES_USE_URL)
if use_url:
- if not getattr(value, 'url', None):
+ if not getattr(value, "url", None):
# If the file has not been saved it may not have a URL.
return None
url = value.url
- request = self.context.get('request', None)
+ request = self.context.get("request", None)
if request is not None:
return request.build_absolute_uri(url)
return url
@@ -1574,13 +1707,13 @@ class FileField(Field):
class ImageField(FileField):
default_error_messages = {
- 'invalid_image': _(
- 'Upload a valid image. The file you uploaded was either not an image or a corrupted image.'
- ),
+ "invalid_image": _(
+ "Upload a valid image. The file you uploaded was either not an image or a corrupted image."
+ )
}
def __init__(self, *args, **kwargs):
- self._DjangoImageField = kwargs.pop('_DjangoImageField', DjangoImageField)
+ self._DjangoImageField = kwargs.pop("_DjangoImageField", DjangoImageField)
super(ImageField, self).__init__(*args, **kwargs)
def to_internal_value(self, data):
@@ -1595,6 +1728,7 @@ class ImageField(FileField):
# Composite field types...
+
class _UnvalidatedField(Field):
def __init__(self, *args, **kwargs):
super(_UnvalidatedField, self).__init__(*args, **kwargs)
@@ -1612,36 +1746,40 @@ class ListField(Field):
child = _UnvalidatedField()
initial = []
default_error_messages = {
- 'not_a_list': _('Expected a list of items but got type "{input_type}".'),
- 'empty': _('This list may not be empty.'),
- 'min_length': _('Ensure this field has at least {min_length} elements.'),
- 'max_length': _('Ensure this field has no more than {max_length} elements.')
+ "not_a_list": _('Expected a list of items but got type "{input_type}".'),
+ "empty": _("This list may not be empty."),
+ "min_length": _("Ensure this field has at least {min_length} elements."),
+ "max_length": _("Ensure this field has no more than {max_length} elements."),
}
def __init__(self, *args, **kwargs):
- self.child = kwargs.pop('child', copy.deepcopy(self.child))
- self.allow_empty = kwargs.pop('allow_empty', True)
- self.max_length = kwargs.pop('max_length', None)
- self.min_length = kwargs.pop('min_length', None)
+ self.child = kwargs.pop("child", copy.deepcopy(self.child))
+ self.allow_empty = kwargs.pop("allow_empty", True)
+ self.max_length = kwargs.pop("max_length", None)
+ self.min_length = kwargs.pop("min_length", None)
- assert not inspect.isclass(self.child), '`child` has not been instantiated.'
+ assert not inspect.isclass(self.child), "`child` has not been instantiated."
assert self.child.source is None, (
"The `source` argument is not meaningful when applied to a `child=` field. "
"Remove `source=` from the field declaration."
)
super(ListField, self).__init__(*args, **kwargs)
- self.child.bind(field_name='', parent=self)
+ self.child.bind(field_name="", parent=self)
if self.max_length is not None:
- message = self.error_messages['max_length'].format(max_length=self.max_length)
+ message = self.error_messages["max_length"].format(
+ max_length=self.max_length
+ )
self.validators.append(MaxLengthValidator(self.max_length, message=message))
if self.min_length is not None:
- message = self.error_messages['min_length'].format(min_length=self.min_length)
+ message = self.error_messages["min_length"].format(
+ min_length=self.min_length
+ )
self.validators.append(MinLengthValidator(self.min_length, message=message))
def get_value(self, dictionary):
if self.field_name not in dictionary:
- if getattr(self.root, 'partial', False):
+ if getattr(self.root, "partial", False):
return empty
# We override the default field access in order to support
# lists in HTML forms.
@@ -1650,7 +1788,9 @@ class ListField(Field):
if len(val) > 0:
# Support QueryDict lists in HTML input.
return val
- return html.parse_html_list(dictionary, prefix=self.field_name, default=empty)
+ return html.parse_html_list(
+ dictionary, prefix=self.field_name, default=empty
+ )
return dictionary.get(self.field_name, empty)
@@ -1660,17 +1800,20 @@ class ListField(Field):
"""
if html.is_html_input(data):
data = html.parse_html_list(data, default=[])
- if isinstance(data, (six.text_type, Mapping)) or not hasattr(data, '__iter__'):
- self.fail('not_a_list', input_type=type(data).__name__)
+ if isinstance(data, (six.text_type, Mapping)) or not hasattr(data, "__iter__"):
+ self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
- self.fail('empty')
+ self.fail("empty")
return self.run_child_validation(data)
def to_representation(self, data):
"""
List of object instances -> List of dicts of primitive datatypes.
"""
- return [self.child.to_representation(item) if item is not None else None for item in data]
+ return [
+ self.child.to_representation(item) if item is not None else None
+ for item in data
+ ]
def run_child_validation(self, data):
result = []
@@ -1691,20 +1834,20 @@ class DictField(Field):
child = _UnvalidatedField()
initial = {}
default_error_messages = {
- 'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".')
+ "not_a_dict": _('Expected a dictionary of items but got type "{input_type}".')
}
def __init__(self, *args, **kwargs):
- self.child = kwargs.pop('child', copy.deepcopy(self.child))
+ self.child = kwargs.pop("child", copy.deepcopy(self.child))
- assert not inspect.isclass(self.child), '`child` has not been instantiated.'
+ assert not inspect.isclass(self.child), "`child` has not been instantiated."
assert self.child.source is None, (
"The `source` argument is not meaningful when applied to a `child=` field. "
"Remove `source=` from the field declaration."
)
super(DictField, self).__init__(*args, **kwargs)
- self.child.bind(field_name='', parent=self)
+ self.child.bind(field_name="", parent=self)
def get_value(self, dictionary):
# We override the default field access in order to support
@@ -1720,12 +1863,14 @@ class DictField(Field):
if html.is_html_input(data):
data = html.parse_html_dict(data)
if not isinstance(data, dict):
- self.fail('not_a_dict', input_type=type(data).__name__)
+ self.fail("not_a_dict", input_type=type(data).__name__)
return self.run_child_validation(data)
def to_representation(self, value):
return {
- six.text_type(key): self.child.to_representation(val) if val is not None else None
+ six.text_type(key): self.child.to_representation(val)
+ if val is not None
+ else None
for key, val in value.items()
}
@@ -1758,12 +1903,10 @@ class HStoreField(DictField):
class JSONField(Field):
- default_error_messages = {
- 'invalid': _('Value must be valid JSON.')
- }
+ default_error_messages = {"invalid": _("Value must be valid JSON.")}
def __init__(self, *args, **kwargs):
- self.binary = kwargs.pop('binary', False)
+ self.binary = kwargs.pop("binary", False)
super(JSONField, self).__init__(*args, **kwargs)
def get_value(self, dictionary):
@@ -1775,19 +1918,20 @@ class JSONField(Field):
ret = six.text_type.__new__(self, value)
ret.is_json_string = True
return ret
+
return JSONString(dictionary[self.field_name])
return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
try:
- if self.binary or getattr(data, 'is_json_string', False):
+ if self.binary or getattr(data, "is_json_string", False):
if isinstance(data, bytes):
- data = data.decode('utf-8')
+ data = data.decode("utf-8")
return json.loads(data)
else:
json.dumps(data)
except (TypeError, ValueError):
- self.fail('invalid')
+ self.fail("invalid")
return data
def to_representation(self, value):
@@ -1796,12 +1940,13 @@ class JSONField(Field):
# On python 2.x the return type for json.dumps() is underspecified.
# On python 3.x json.dumps() returns unicode strings.
if isinstance(value, six.text_type):
- value = bytes(value.encode('utf-8'))
+ value = bytes(value.encode("utf-8"))
return value
# Miscellaneous field types...
+
class ReadOnlyField(Field):
"""
A read-only field that simply returns the field value.
@@ -1816,7 +1961,7 @@ class ReadOnlyField(Field):
"""
def __init__(self, **kwargs):
- kwargs['read_only'] = True
+ kwargs["read_only"] = True
super(ReadOnlyField, self).__init__(**kwargs)
def to_representation(self, value):
@@ -1831,9 +1976,10 @@ class HiddenField(Field):
constraint on a pair of fields, as we need some way to include the date in
the validated data.
"""
+
def __init__(self, **kwargs):
- assert 'default' in kwargs, 'default is a required argument.'
- kwargs['write_only'] = True
+ assert "default" in kwargs, "default is a required argument."
+ kwargs["write_only"] = True
super(HiddenField, self).__init__(**kwargs)
def get_value(self, dictionary):
@@ -1860,22 +2006,23 @@ class SerializerMethodField(Field):
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""
+
def __init__(self, method_name=None, **kwargs):
self.method_name = method_name
- kwargs['source'] = '*'
- kwargs['read_only'] = True
+ kwargs["source"] = "*"
+ kwargs["read_only"] = True
super(SerializerMethodField, self).__init__(**kwargs)
def bind(self, field_name, parent):
# In order to enforce a consistent style, we error if a redundant
# 'method_name' argument has been used. For example:
# my_field = serializer.SerializerMethodField(method_name='get_my_field')
- default_method_name = 'get_{field_name}'.format(field_name=field_name)
+ default_method_name = "get_{field_name}".format(field_name=field_name)
assert self.method_name != default_method_name, (
"It is redundant to specify `%s` on SerializerMethodField '%s' in "
"serializer '%s', because it is the same as the default method name. "
- "Remove the `method_name` argument." %
- (self.method_name, field_name, parent.__class__.__name__)
+ "Remove the `method_name` argument."
+ % (self.method_name, field_name, parent.__class__.__name__)
)
# The method name should default to `get_{field_name}`.
@@ -1896,22 +2043,22 @@ class ModelField(Field):
This is used by `ModelSerializer` when dealing with custom model fields,
that do not have a serializer field to be mapped to.
"""
+
default_error_messages = {
- 'max_length': _('Ensure this field has no more than {max_length} characters.'),
+ "max_length": _("Ensure this field has no more than {max_length} characters.")
}
def __init__(self, model_field, **kwargs):
self.model_field = model_field
# The `max_length` option is supported by Django's base `Field` class,
# so we'd better support it here.
- max_length = kwargs.pop('max_length', None)
+ max_length = kwargs.pop("max_length", None)
super(ModelField, self).__init__(**kwargs)
if max_length is not None:
- message = lazy(
- self.error_messages['max_length'].format,
- six.text_type)(max_length=self.max_length)
- self.validators.append(
- MaxLengthValidator(self.max_length, message=message))
+ message = lazy(self.error_messages["max_length"].format, six.text_type)(
+ max_length=self.max_length
+ )
+ self.validators.append(MaxLengthValidator(self.max_length, message=message))
def to_internal_value(self, data):
rel = self.model_field.remote_field
diff --git a/rest_framework/filters.py b/rest_framework/filters.py
index bb1b86586..34d7d6225 100644
--- a/rest_framework/filters.py
+++ b/rest_framework/filters.py
@@ -18,9 +18,7 @@ from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _
from rest_framework import RemovedInDRF310Warning
-from rest_framework.compat import (
- coreapi, coreschema, distinct, is_guardian_installed
-)
+from rest_framework.compat import coreapi, coreschema, distinct, is_guardian_installed
from rest_framework.settings import api_settings
@@ -36,23 +34,22 @@ class BaseFilterBackend(object):
raise NotImplementedError(".filter_queryset() must be overridden.")
def get_schema_fields(self, view):
- assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
- assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
+ assert (
+ coreapi is not None
+ ), "coreapi must be installed to use `get_schema_fields()`"
+ assert (
+ coreschema is not None
+ ), "coreschema must be installed to use `get_schema_fields()`"
return []
class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search.
search_param = api_settings.SEARCH_PARAM
- template = 'rest_framework/filters/search.html'
- lookup_prefixes = {
- '^': 'istartswith',
- '=': 'iexact',
- '@': 'search',
- '$': 'iregex',
- }
- search_title = _('Search')
- search_description = _('A search term.')
+ template = "rest_framework/filters/search.html"
+ lookup_prefixes = {"^": "istartswith", "=": "iexact", "@": "search", "$": "iregex"}
+ search_title = _("Search")
+ search_description = _("A search term.")
def get_search_fields(self, view, request):
"""
@@ -60,22 +57,22 @@ class SearchFilter(BaseFilterBackend):
passed to this method. Sub-classes can override this method to
dynamically change the search fields based on request content.
"""
- return getattr(view, 'search_fields', None)
+ return getattr(view, "search_fields", None)
def get_search_terms(self, request):
"""
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
"""
- params = request.query_params.get(self.search_param, '')
- return params.replace(',', ' ').split()
+ params = request.query_params.get(self.search_param, "")
+ return params.replace(",", " ").split()
def construct_search(self, field_name):
lookup = self.lookup_prefixes.get(field_name[0])
if lookup:
field_name = field_name[1:]
else:
- lookup = 'icontains'
+ lookup = "icontains"
return LOOKUP_SEP.join([field_name, lookup])
def must_call_distinct(self, queryset, search_fields):
@@ -87,12 +84,15 @@ class SearchFilter(BaseFilterBackend):
if search_field[0] in self.lookup_prefixes:
search_field = search_field[1:]
# Annotated fields do not need to be distinct
- if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
+ if (
+ isinstance(queryset, models.QuerySet)
+ and search_field in queryset.query.annotations
+ ):
return False
parts = search_field.split(LOOKUP_SEP)
for part in parts:
field = opts.get_field(part)
- if hasattr(field, 'get_path_info'):
+ if hasattr(field, "get_path_info"):
# This field is a relation, update opts to follow the relation
path_info = field.get_path_info()
opts = path_info[-1].to_opts
@@ -117,8 +117,7 @@ class SearchFilter(BaseFilterBackend):
conditions = []
for search_term in search_terms:
queries = [
- models.Q(**{orm_lookup: search_term})
- for orm_lookup in orm_lookups
+ models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups
]
conditions.append(reduce(operator.or_, queries))
queryset = queryset.filter(reduce(operator.and_, conditions))
@@ -132,30 +131,31 @@ class SearchFilter(BaseFilterBackend):
return queryset
def to_html(self, request, queryset, view):
- if not getattr(view, 'search_fields', None):
- return ''
+ if not getattr(view, "search_fields", None):
+ return ""
term = self.get_search_terms(request)
- term = term[0] if term else ''
- context = {
- 'param': self.search_param,
- 'term': term
- }
+ term = term[0] if term else ""
+ context = {"param": self.search_param, "term": term}
template = loader.get_template(self.template)
return template.render(context)
def get_schema_fields(self, view):
- assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
- assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
+ assert (
+ coreapi is not None
+ ), "coreapi must be installed to use `get_schema_fields()`"
+ assert (
+ coreschema is not None
+ ), "coreschema must be installed to use `get_schema_fields()`"
return [
coreapi.Field(
name=self.search_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.String(
title=force_text(self.search_title),
- description=force_text(self.search_description)
- )
+ description=force_text(self.search_description),
+ ),
)
]
@@ -164,9 +164,9 @@ class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
- ordering_title = _('Ordering')
- ordering_description = _('Which field to use when ordering the results.')
- template = 'rest_framework/filters/ordering.html'
+ ordering_title = _("Ordering")
+ ordering_description = _("Which field to use when ordering the results.")
+ template = "rest_framework/filters/ordering.html"
def get_ordering(self, request, queryset, view):
"""
@@ -178,7 +178,7 @@ class OrderingFilter(BaseFilterBackend):
"""
params = request.query_params.get(self.ordering_param)
if params:
- fields = [param.strip() for param in params.split(',')]
+ fields = [param.strip() for param in params.split(",")]
ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering:
return ordering
@@ -187,7 +187,7 @@ class OrderingFilter(BaseFilterBackend):
return self.get_default_ordering(view)
def get_default_ordering(self, view):
- ordering = getattr(view, 'ordering', None)
+ ordering = getattr(view, "ordering", None)
if isinstance(ordering, six.string_types):
return (ordering,)
return ordering
@@ -195,7 +195,7 @@ class OrderingFilter(BaseFilterBackend):
def get_default_valid_fields(self, queryset, view, context={}):
# If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view.
- if hasattr(view, 'get_serializer_class'):
+ if hasattr(view, "get_serializer_class"):
try:
serializer_class = view.get_serializer_class()
except AssertionError:
@@ -203,7 +203,7 @@ class OrderingFilter(BaseFilterBackend):
# no serializer_class was found
serializer_class = None
else:
- serializer_class = getattr(view, 'serializer_class', None)
+ serializer_class = getattr(view, "serializer_class", None)
if serializer_class is None:
msg = (
@@ -214,26 +214,26 @@ class OrderingFilter(BaseFilterBackend):
raise ImproperlyConfigured(msg % self.__class__.__name__)
return [
- (field.source.replace('.', '__') or field_name, field.label)
+ (field.source.replace(".", "__") or field_name, field.label)
for field_name, field in serializer_class(context=context).fields.items()
- if not getattr(field, 'write_only', False) and not field.source == '*'
+ if not getattr(field, "write_only", False) and not field.source == "*"
]
def get_valid_fields(self, queryset, view, context={}):
- valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
+ valid_fields = getattr(view, "ordering_fields", self.ordering_fields)
if valid_fields is None:
# Default to allowing filtering on serializer fields
return self.get_default_valid_fields(queryset, view, context)
- elif valid_fields == '__all__':
+ elif valid_fields == "__all__":
# View explicitly allows filtering on any model field
valid_fields = [
- (field.name, field.verbose_name) for field in queryset.model._meta.fields
+ (field.name, field.verbose_name)
+ for field in queryset.model._meta.fields
]
valid_fields += [
- (key, key.title().split('__'))
- for key in queryset.query.annotations
+ (key, key.title().split("__")) for key in queryset.query.annotations
]
else:
valid_fields = [
@@ -244,8 +244,15 @@ class OrderingFilter(BaseFilterBackend):
return valid_fields
def remove_invalid_fields(self, queryset, fields, view, request):
- valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
- return [term for term in fields if term.lstrip('-') in valid_fields and ORDER_PATTERN.match(term)]
+ valid_fields = [
+ item[0]
+ for item in self.get_valid_fields(queryset, view, {"request": request})
+ ]
+ return [
+ term
+ for term in fields
+ if term.lstrip("-") in valid_fields and ORDER_PATTERN.match(term)
+ ]
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view)
@@ -259,15 +266,11 @@ class OrderingFilter(BaseFilterBackend):
current = self.get_ordering(request, queryset, view)
current = None if not current else current[0]
options = []
- context = {
- 'request': request,
- 'current': current,
- 'param': self.ordering_param,
- }
+ context = {"request": request, "current": current, "param": self.ordering_param}
for key, label in self.get_valid_fields(queryset, view, context):
- options.append((key, '%s - %s' % (label, _('ascending'))))
- options.append(('-' + key, '%s - %s' % (label, _('descending'))))
- context['options'] = options
+ options.append((key, "%s - %s" % (label, _("ascending"))))
+ options.append(("-" + key, "%s - %s" % (label, _("descending"))))
+ context["options"] = options
return context
def to_html(self, request, queryset, view):
@@ -276,17 +279,21 @@ class OrderingFilter(BaseFilterBackend):
return template.render(context)
def get_schema_fields(self, view):
- assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
- assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
+ assert (
+ coreapi is not None
+ ), "coreapi must be installed to use `get_schema_fields()`"
+ assert (
+ coreschema is not None
+ ), "coreschema must be installed to use `get_schema_fields()`"
return [
coreapi.Field(
name=self.ordering_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.String(
title=force_text(self.ordering_title),
- description=force_text(self.ordering_description)
- )
+ description=force_text(self.ordering_description),
+ ),
)
]
@@ -296,15 +303,19 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):
A filter backend that limits results to those where the requesting user
has read object level permissions.
"""
+
def __init__(self):
warnings.warn(
"`DjangoObjectPermissionsFilter` has been deprecated and moved to "
"the 3rd-party django-rest-framework-guardian package.",
- RemovedInDRF310Warning, stacklevel=2
+ RemovedInDRF310Warning,
+ stacklevel=2,
)
- assert is_guardian_installed(), 'Using DjangoObjectPermissionsFilter, but django-guardian is not installed'
+ assert (
+ is_guardian_installed()
+ ), "Using DjangoObjectPermissionsFilter, but django-guardian is not installed"
- perm_format = '%(app_label)s.view_%(model_name)s'
+ perm_format = "%(app_label)s.view_%(model_name)s"
def filter_queryset(self, request, queryset, view):
# We want to defer this import until run-time, rather than import-time.
@@ -317,13 +328,13 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):
user = request.user
model_cls = queryset.model
kwargs = {
- 'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.model_name
+ "app_label": model_cls._meta.app_label,
+ "model_name": model_cls._meta.model_name,
}
permission = self.perm_format % kwargs
if tuple(guardian_version) >= (1, 3):
# Maintain behavior compatibility with versions prior to 1.3
- extra = {'accept_global_perms': False}
+ extra = {"accept_global_perms": False}
else:
extra = {}
return get_objects_for_user(user, permission, queryset, **extra)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 8d0bf284a..e5e132422 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -27,6 +27,7 @@ class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
+
# You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`.
# If you are overriding a view method, it is important that you call
@@ -38,7 +39,7 @@ class GenericAPIView(views.APIView):
# If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
- lookup_field = 'pk'
+ lookup_field = "pk"
lookup_url_kwarg = None
# The filter backend classes to use for queryset filtering
@@ -64,8 +65,7 @@ class GenericAPIView(views.APIView):
"""
assert self.queryset is not None, (
"'%s' should either include a `queryset` attribute, "
- "or override the `get_queryset()` method."
- % self.__class__.__name__
+ "or override the `get_queryset()` method." % self.__class__.__name__
)
queryset = self.queryset
@@ -88,10 +88,10 @@ class GenericAPIView(views.APIView):
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
assert lookup_url_kwarg in self.kwargs, (
- 'Expected view %s to be called with a URL keyword argument '
+ "Expected view %s to be called with a URL keyword argument "
'named "%s". Fix your URL conf, or set the `.lookup_field` '
- 'attribute on the view correctly.' %
- (self.__class__.__name__, lookup_url_kwarg)
+ "attribute on the view correctly."
+ % (self.__class__.__name__, lookup_url_kwarg)
)
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
@@ -108,7 +108,7 @@ class GenericAPIView(views.APIView):
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
- kwargs['context'] = self.get_serializer_context()
+ kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)
def get_serializer_class(self):
@@ -123,8 +123,7 @@ class GenericAPIView(views.APIView):
"""
assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, "
- "or override the `get_serializer_class()` method."
- % self.__class__.__name__
+ "or override the `get_serializer_class()` method." % self.__class__.__name__
)
return self.serializer_class
@@ -133,11 +132,7 @@ class GenericAPIView(views.APIView):
"""
Extra context provided to the serializer class.
"""
- return {
- 'request': self.request,
- 'format': self.format_kwarg,
- 'view': self
- }
+ return {"request": self.request, "format": self.format_kwarg, "view": self}
def filter_queryset(self, queryset):
"""
@@ -157,7 +152,7 @@ class GenericAPIView(views.APIView):
"""
The paginator instance associated with the view, or `None`.
"""
- if not hasattr(self, '_paginator'):
+ if not hasattr(self, "_paginator"):
if self.pagination_class is None:
self._paginator = None
else:
@@ -183,47 +178,48 @@ class GenericAPIView(views.APIView):
# Concrete view classes that provide method handlers
# by composing the mixin classes with the base view.
-class CreateAPIView(mixins.CreateModelMixin,
- GenericAPIView):
+
+class CreateAPIView(mixins.CreateModelMixin, GenericAPIView):
"""
Concrete view for creating a model instance.
"""
+
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
-class ListAPIView(mixins.ListModelMixin,
- GenericAPIView):
+class ListAPIView(mixins.ListModelMixin, GenericAPIView):
"""
Concrete view for listing a queryset.
"""
+
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
-class RetrieveAPIView(mixins.RetrieveModelMixin,
- GenericAPIView):
+class RetrieveAPIView(mixins.RetrieveModelMixin, GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
+
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
-class DestroyAPIView(mixins.DestroyModelMixin,
- GenericAPIView):
+class DestroyAPIView(mixins.DestroyModelMixin, GenericAPIView):
"""
Concrete view for deleting a model instance.
"""
+
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
-class UpdateAPIView(mixins.UpdateModelMixin,
- GenericAPIView):
+class UpdateAPIView(mixins.UpdateModelMixin, GenericAPIView):
"""
Concrete view for updating a model instance.
"""
+
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
@@ -231,12 +227,11 @@ class UpdateAPIView(mixins.UpdateModelMixin,
return self.partial_update(request, *args, **kwargs)
-class ListCreateAPIView(mixins.ListModelMixin,
- mixins.CreateModelMixin,
- GenericAPIView):
+class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
+
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
@@ -244,12 +239,13 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs)
-class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
- mixins.UpdateModelMixin,
- GenericAPIView):
+class RetrieveUpdateAPIView(
+ mixins.RetrieveModelMixin, mixins.UpdateModelMixin, GenericAPIView
+):
"""
Concrete view for retrieving, updating a model instance.
"""
+
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@@ -260,12 +256,13 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
return self.partial_update(request, *args, **kwargs)
-class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
- mixins.DestroyModelMixin,
- GenericAPIView):
+class RetrieveDestroyAPIView(
+ mixins.RetrieveModelMixin, mixins.DestroyModelMixin, GenericAPIView
+):
"""
Concrete view for retrieving or deleting a model instance.
"""
+
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@@ -273,13 +270,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
return self.destroy(request, *args, **kwargs)
-class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
- mixins.UpdateModelMixin,
- mixins.DestroyModelMixin,
- GenericAPIView):
+class RetrieveUpdateDestroyAPIView(
+ mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ mixins.DestroyModelMixin,
+ GenericAPIView,
+):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
+
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py
index 591073ba0..aad9743c3 100644
--- a/rest_framework/management/commands/generateschema.py
+++ b/rest_framework/management/commands/generateschema.py
@@ -2,7 +2,9 @@ from django.core.management.base import BaseCommand
from rest_framework.compat import coreapi
from rest_framework.renderers import (
- CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer
+ CoreJSONRenderer,
+ JSONOpenAPIRenderer,
+ OpenAPIRenderer,
)
from rest_framework.schemas.generators import SchemaGenerator
@@ -11,31 +13,37 @@ class Command(BaseCommand):
help = "Generates configured API schema for project."
def add_arguments(self, parser):
- parser.add_argument('--title', dest="title", default=None, type=str)
- parser.add_argument('--url', dest="url", default=None, type=str)
- parser.add_argument('--description', dest="description", default=None, type=str)
- parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
+ parser.add_argument("--title", dest="title", default=None, type=str)
+ parser.add_argument("--url", dest="url", default=None, type=str)
+ parser.add_argument("--description", dest="description", default=None, type=str)
+ parser.add_argument(
+ "--format",
+ dest="format",
+ choices=["openapi", "openapi-json", "corejson"],
+ default="openapi",
+ type=str,
+ )
def handle(self, *args, **options):
- assert coreapi is not None, 'coreapi must be installed.'
+ assert coreapi is not None, "coreapi must be installed."
generator = SchemaGenerator(
- url=options['url'],
- title=options['title'],
- description=options['description']
+ url=options["url"],
+ title=options["title"],
+ description=options["description"],
)
schema = generator.get_schema(request=None, public=True)
- renderer = self.get_renderer(options['format'])
+ renderer = self.get_renderer(options["format"])
output = renderer.render(schema, renderer_context={})
- self.stdout.write(output.decode('utf-8'))
+ self.stdout.write(output.decode("utf-8"))
def get_renderer(self, format):
renderer_cls = {
- 'corejson': CoreJSONRenderer,
- 'openapi': OpenAPIRenderer,
- 'openapi-json': JSONOpenAPIRenderer,
+ "corejson": CoreJSONRenderer,
+ "openapi": OpenAPIRenderer,
+ "openapi-json": JSONOpenAPIRenderer,
}[format]
return renderer_cls()
diff --git a/rest_framework/metadata.py b/rest_framework/metadata.py
index 9f9324469..76b0370f5 100644
--- a/rest_framework/metadata.py
+++ b/rest_framework/metadata.py
@@ -35,41 +35,46 @@ class SimpleMetadata(BaseMetadata):
There are not any formalized standards for `OPTIONS` responses
for us to base this on.
"""
- label_lookup = ClassLookupDict({
- serializers.Field: 'field',
- serializers.BooleanField: 'boolean',
- serializers.NullBooleanField: 'boolean',
- serializers.CharField: 'string',
- serializers.UUIDField: 'string',
- serializers.URLField: 'url',
- serializers.EmailField: 'email',
- serializers.RegexField: 'regex',
- serializers.SlugField: 'slug',
- serializers.IntegerField: 'integer',
- serializers.FloatField: 'float',
- serializers.DecimalField: 'decimal',
- serializers.DateField: 'date',
- serializers.DateTimeField: 'datetime',
- serializers.TimeField: 'time',
- serializers.ChoiceField: 'choice',
- serializers.MultipleChoiceField: 'multiple choice',
- serializers.FileField: 'file upload',
- serializers.ImageField: 'image upload',
- serializers.ListField: 'list',
- serializers.DictField: 'nested object',
- serializers.Serializer: 'nested object',
- })
+
+ label_lookup = ClassLookupDict(
+ {
+ serializers.Field: "field",
+ serializers.BooleanField: "boolean",
+ serializers.NullBooleanField: "boolean",
+ serializers.CharField: "string",
+ serializers.UUIDField: "string",
+ serializers.URLField: "url",
+ serializers.EmailField: "email",
+ serializers.RegexField: "regex",
+ serializers.SlugField: "slug",
+ serializers.IntegerField: "integer",
+ serializers.FloatField: "float",
+ serializers.DecimalField: "decimal",
+ serializers.DateField: "date",
+ serializers.DateTimeField: "datetime",
+ serializers.TimeField: "time",
+ serializers.ChoiceField: "choice",
+ serializers.MultipleChoiceField: "multiple choice",
+ serializers.FileField: "file upload",
+ serializers.ImageField: "image upload",
+ serializers.ListField: "list",
+ serializers.DictField: "nested object",
+ serializers.Serializer: "nested object",
+ }
+ )
def determine_metadata(self, request, view):
metadata = OrderedDict()
- metadata['name'] = view.get_view_name()
- metadata['description'] = view.get_view_description()
- metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes]
- metadata['parses'] = [parser.media_type for parser in view.parser_classes]
- if hasattr(view, 'get_serializer'):
+ metadata["name"] = view.get_view_name()
+ metadata["description"] = view.get_view_description()
+ metadata["renders"] = [
+ renderer.media_type for renderer in view.renderer_classes
+ ]
+ metadata["parses"] = [parser.media_type for parser in view.parser_classes]
+ if hasattr(view, "get_serializer"):
actions = self.determine_actions(request, view)
if actions:
- metadata['actions'] = actions
+ metadata["actions"] = actions
return metadata
def determine_actions(self, request, view):
@@ -78,14 +83,14 @@ class SimpleMetadata(BaseMetadata):
the fields that are accepted for 'PUT' and 'POST' methods.
"""
actions = {}
- for method in {'PUT', 'POST'} & set(view.allowed_methods):
+ for method in {"PUT", "POST"} & set(view.allowed_methods):
view.request = clone_request(request, method)
try:
# Test global permissions
- if hasattr(view, 'check_permissions'):
+ if hasattr(view, "check_permissions"):
view.check_permissions(view.request)
# Test object permissions
- if method == 'PUT' and hasattr(view, 'get_object'):
+ if method == "PUT" and hasattr(view, "get_object"):
view.get_object()
except (exceptions.APIException, PermissionDenied, Http404):
pass
@@ -104,15 +109,17 @@ class SimpleMetadata(BaseMetadata):
Given an instance of a serializer, return a dictionary of metadata
about its fields.
"""
- if hasattr(serializer, 'child'):
+ if hasattr(serializer, "child"):
# If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead.
serializer = serializer.child
- return OrderedDict([
- (field_name, self.get_field_info(field))
- for field_name, field in serializer.fields.items()
- if not isinstance(field, serializers.HiddenField)
- ])
+ return OrderedDict(
+ [
+ (field_name, self.get_field_info(field))
+ for field_name, field in serializer.fields.items()
+ if not isinstance(field, serializers.HiddenField)
+ ]
+ )
def get_field_info(self, field):
"""
@@ -120,32 +127,40 @@ class SimpleMetadata(BaseMetadata):
of metadata about it.
"""
field_info = OrderedDict()
- field_info['type'] = self.label_lookup[field]
- field_info['required'] = getattr(field, 'required', False)
+ field_info["type"] = self.label_lookup[field]
+ field_info["required"] = getattr(field, "required", False)
attrs = [
- 'read_only', 'label', 'help_text',
- 'min_length', 'max_length',
- 'min_value', 'max_value'
+ "read_only",
+ "label",
+ "help_text",
+ "min_length",
+ "max_length",
+ "min_value",
+ "max_value",
]
for attr in attrs:
value = getattr(field, attr, None)
- if value is not None and value != '':
+ if value is not None and value != "":
field_info[attr] = force_text(value, strings_only=True)
- if getattr(field, 'child', None):
- field_info['child'] = self.get_field_info(field.child)
- elif getattr(field, 'fields', None):
- field_info['children'] = self.get_serializer_info(field)
+ if getattr(field, "child", None):
+ field_info["child"] = self.get_field_info(field.child)
+ elif getattr(field, "fields", None):
+ field_info["children"] = self.get_serializer_info(field)
- if (not field_info.get('read_only') and
- not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and
- hasattr(field, 'choices')):
- field_info['choices'] = [
+ if (
+ not field_info.get("read_only")
+ and not isinstance(
+ field, (serializers.RelatedField, serializers.ManyRelatedField)
+ )
+ and hasattr(field, "choices")
+ ):
+ field_info["choices"] = [
{
- 'value': choice_value,
- 'display_name': force_text(choice_name, strings_only=True)
+ "value": choice_value,
+ "display_name": force_text(choice_name, strings_only=True),
}
for choice_value, choice_name in field.choices.items()
]
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index de10d6930..4855140ad 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -15,19 +15,22 @@ class CreateModelMixin(object):
"""
Create a model instance.
"""
+
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
- return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
+ return Response(
+ serializer.data, status=status.HTTP_201_CREATED, headers=headers
+ )
def perform_create(self, serializer):
serializer.save()
def get_success_headers(self, data):
try:
- return {'Location': str(data[api_settings.URL_FIELD_NAME])}
+ return {"Location": str(data[api_settings.URL_FIELD_NAME])}
except (TypeError, KeyError):
return {}
@@ -36,6 +39,7 @@ class ListModelMixin(object):
"""
List a queryset.
"""
+
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
@@ -52,6 +56,7 @@ class RetrieveModelMixin(object):
"""
Retrieve a model instance.
"""
+
def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
@@ -62,14 +67,15 @@ class UpdateModelMixin(object):
"""
Update a model instance.
"""
+
def update(self, request, *args, **kwargs):
- partial = kwargs.pop('partial', False)
+ partial = kwargs.pop("partial", False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
- if getattr(instance, '_prefetched_objects_cache', None):
+ if getattr(instance, "_prefetched_objects_cache", None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
@@ -80,7 +86,7 @@ class UpdateModelMixin(object):
serializer.save()
def partial_update(self, request, *args, **kwargs):
- kwargs['partial'] = True
+ kwargs["partial"] = True
return self.update(request, *args, **kwargs)
@@ -88,6 +94,7 @@ class DestroyModelMixin(object):
"""
Destroy a model instance.
"""
+
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
self.perform_destroy(instance)
diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py
index ca1b59f12..e9d89dd14 100644
--- a/rest_framework/negotiation.py
+++ b/rest_framework/negotiation.py
@@ -9,16 +9,18 @@ from django.http import Http404
from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import (
- _MediaType, media_type_matches, order_by_precedence
+ _MediaType,
+ media_type_matches,
+ order_by_precedence,
)
class BaseContentNegotiation(object):
def select_parser(self, request, parsers):
- raise NotImplementedError('.select_parser() must be implemented')
+ raise NotImplementedError(".select_parser() must be implemented")
def select_renderer(self, request, renderers, format_suffix=None):
- raise NotImplementedError('.select_renderer() must be implemented')
+ raise NotImplementedError(".select_renderer() must be implemented")
class DefaultContentNegotiation(BaseContentNegotiation):
@@ -59,16 +61,20 @@ class DefaultContentNegotiation(BaseContentNegotiation):
# Return the most specific media type as accepted.
media_type_wrapper = _MediaType(media_type)
if (
- _MediaType(renderer.media_type).precedence >
- media_type_wrapper.precedence
+ _MediaType(renderer.media_type).precedence
+ > media_type_wrapper.precedence
):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
- full_media_type = ';'.join(
- (renderer.media_type,) +
- tuple('{0}={1}'.format(
- key, value.decode(HTTP_HEADER_ENCODING))
- for key, value in media_type_wrapper.params.items()))
+ full_media_type = ";".join(
+ (renderer.media_type,)
+ + tuple(
+ "{0}={1}".format(
+ key, value.decode(HTTP_HEADER_ENCODING)
+ )
+ for key, value in media_type_wrapper.params.items()
+ )
+ )
return renderer, full_media_type
else:
# Eg client requests 'application/json; indent=8'
@@ -82,8 +88,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
If there is a '.json' style format suffix, filter the renderers
so that we only negotiation against those that accept that format.
"""
- renderers = [renderer for renderer in renderers
- if renderer.format == format]
+ renderers = [renderer for renderer in renderers if renderer.format == format]
if not renderers:
raise Http404
return renderers
@@ -93,5 +98,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Given the incoming request, return a tokenized list of media
type strings.
"""
- header = request.META.get('HTTP_ACCEPT', '*/*')
- return [token.strip() for token in header.split(',')]
+ header = request.META.get("HTTP_ACCEPT", "*/*")
+ return [token.strip() for token in header.split(",")]
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index b11d7cdf3..5b61a4b79 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -8,8 +8,7 @@ from __future__ import unicode_literals
from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple
-from django.core.paginator import InvalidPage
-from django.core.paginator import Paginator as DjangoPaginator
+from django.core.paginator import InvalidPage, Paginator as DjangoPaginator
from django.template import loader
from django.utils import six
from django.utils.encoding import force_text
@@ -83,10 +82,7 @@ def _get_displayed_page_numbers(current, final):
included.add(final - 2)
# Now sort the page numbers and drop anything outside the limits.
- included = [
- idx for idx in sorted(list(included))
- if 0 < idx <= final
- ]
+ included = [idx for idx in sorted(list(included)) if 0 < idx <= final]
# Finally insert any `...` breaks
if current > 4:
@@ -110,7 +106,7 @@ def _get_page_links(page_numbers, current, url_func):
url=url_func(page_number),
number=page_number,
is_active=(page_number == current),
- is_break=False
+ is_break=False,
)
page_links.append(page_link)
return page_links
@@ -121,14 +117,15 @@ def _reverse_ordering(ordering_tuple):
Given an order_by tuple such as `('-created', 'uuid')` reverse the
ordering and return a new tuple, eg. `('created', '-uuid')`.
"""
+
def invert(x):
- return x[1:] if x.startswith('-') else '-' + x
+ return x[1:] if x.startswith("-") else "-" + x
return tuple([invert(item) for item in ordering_tuple])
-Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position'])
-PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break'])
+Cursor = namedtuple("Cursor", ["offset", "reverse", "position"])
+PageLink = namedtuple("PageLink", ["url", "number", "is_active", "is_break"])
PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True)
@@ -137,19 +134,23 @@ class BasePagination(object):
display_page_controls = False
def paginate_queryset(self, queryset, request, view=None): # pragma: no cover
- raise NotImplementedError('paginate_queryset() must be implemented.')
+ raise NotImplementedError("paginate_queryset() must be implemented.")
def get_paginated_response(self, data): # pragma: no cover
- raise NotImplementedError('get_paginated_response() must be implemented.')
+ raise NotImplementedError("get_paginated_response() must be implemented.")
def to_html(self): # pragma: no cover
- raise NotImplementedError('to_html() must be implemented to display page controls.')
+ raise NotImplementedError(
+ "to_html() must be implemented to display page controls."
+ )
def get_results(self, data):
- return data['results']
+ return data["results"]
def get_schema_fields(self, view):
- assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
+ assert (
+ coreapi is not None
+ ), "coreapi must be installed to use `get_schema_fields()`"
return []
@@ -161,6 +162,7 @@ class PageNumberPagination(BasePagination):
http://api.example.org/accounts/?page=4
http://api.example.org/accounts/?page=4&page_size=100
"""
+
# The default page size.
# Defaults to `None`, meaning pagination is disabled.
page_size = api_settings.PAGE_SIZE
@@ -168,23 +170,23 @@ class PageNumberPagination(BasePagination):
django_paginator_class = DjangoPaginator
# Client can control the page using this query parameter.
- page_query_param = 'page'
- page_query_description = _('A page number within the paginated result set.')
+ page_query_param = "page"
+ page_query_description = _("A page number within the paginated result set.")
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
- page_size_query_description = _('Number of results to return per page.')
+ page_size_query_description = _("Number of results to return per page.")
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
max_page_size = None
- last_page_strings = ('last',)
+ last_page_strings = ("last",)
- template = 'rest_framework/pagination/numbers.html'
+ template = "rest_framework/pagination/numbers.html"
- invalid_page_message = _('Invalid page.')
+ invalid_page_message = _("Invalid page.")
def paginate_queryset(self, queryset, request, view=None):
"""
@@ -216,12 +218,16 @@ class PageNumberPagination(BasePagination):
return list(self.page)
def get_paginated_response(self, data):
- return Response(OrderedDict([
- ('count', self.page.paginator.count),
- ('next', self.get_next_link()),
- ('previous', self.get_previous_link()),
- ('results', data)
- ]))
+ return Response(
+ OrderedDict(
+ [
+ ("count", self.page.paginator.count),
+ ("next", self.get_next_link()),
+ ("previous", self.get_previous_link()),
+ ("results", data),
+ ]
+ )
+ )
def get_page_size(self, request):
if self.page_size_query_param:
@@ -229,7 +235,7 @@ class PageNumberPagination(BasePagination):
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
- cutoff=self.max_page_size
+ cutoff=self.max_page_size,
)
except (KeyError, ValueError):
pass
@@ -267,9 +273,9 @@ class PageNumberPagination(BasePagination):
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
- 'previous_url': self.get_previous_link(),
- 'next_url': self.get_next_link(),
- 'page_links': page_links
+ "previous_url": self.get_previous_link(),
+ "next_url": self.get_next_link(),
+ "page_links": page_links,
}
def to_html(self):
@@ -278,17 +284,20 @@ class PageNumberPagination(BasePagination):
return template.render(context)
def get_schema_fields(self, view):
- assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
- assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
+ assert (
+ coreapi is not None
+ ), "coreapi must be installed to use `get_schema_fields()`"
+ assert (
+ coreschema is not None
+ ), "coreschema must be installed to use `get_schema_fields()`"
fields = [
coreapi.Field(
name=self.page_query_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.Integer(
- title='Page',
- description=force_text(self.page_query_description)
- )
+ title="Page", description=force_text(self.page_query_description)
+ ),
)
]
if self.page_size_query_param is not None:
@@ -296,11 +305,11 @@ class PageNumberPagination(BasePagination):
coreapi.Field(
name=self.page_size_query_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.Integer(
- title='Page size',
- description=force_text(self.page_size_query_description)
- )
+ title="Page size",
+ description=force_text(self.page_size_query_description),
+ ),
)
)
return fields
@@ -313,13 +322,14 @@ class LimitOffsetPagination(BasePagination):
http://api.example.org/accounts/?limit=100
http://api.example.org/accounts/?offset=400&limit=100
"""
+
default_limit = api_settings.PAGE_SIZE
- limit_query_param = 'limit'
- limit_query_description = _('Number of results to return per page.')
- offset_query_param = 'offset'
- offset_query_description = _('The initial index from which to return the results.')
+ limit_query_param = "limit"
+ limit_query_description = _("Number of results to return per page.")
+ offset_query_param = "offset"
+ offset_query_description = _("The initial index from which to return the results.")
max_limit = None
- template = 'rest_framework/pagination/numbers.html'
+ template = "rest_framework/pagination/numbers.html"
def paginate_queryset(self, queryset, request, view=None):
self.count = self.get_count(queryset)
@@ -334,15 +344,19 @@ class LimitOffsetPagination(BasePagination):
if self.count == 0 or self.offset > self.count:
return []
- return list(queryset[self.offset:self.offset + self.limit])
+ return list(queryset[self.offset : self.offset + self.limit])
def get_paginated_response(self, data):
- return Response(OrderedDict([
- ('count', self.count),
- ('next', self.get_next_link()),
- ('previous', self.get_previous_link()),
- ('results', data)
- ]))
+ return Response(
+ OrderedDict(
+ [
+ ("count", self.count),
+ ("next", self.get_next_link()),
+ ("previous", self.get_previous_link()),
+ ("results", data),
+ ]
+ )
+ )
def get_limit(self, request):
if self.limit_query_param:
@@ -350,7 +364,7 @@ class LimitOffsetPagination(BasePagination):
return _positive_int(
request.query_params[self.limit_query_param],
strict=True,
- cutoff=self.max_limit
+ cutoff=self.max_limit,
)
except (KeyError, ValueError):
pass
@@ -359,9 +373,7 @@ class LimitOffsetPagination(BasePagination):
def get_offset(self, request):
try:
- return _positive_int(
- request.query_params[self.offset_query_param],
- )
+ return _positive_int(request.query_params[self.offset_query_param])
except (KeyError, ValueError):
return 0
@@ -399,10 +411,9 @@ class LimitOffsetPagination(BasePagination):
# plus the number of pages up to the current offset.
# When offset is not strictly divisible by the limit then we may
# end up introducing an extra page as an artifact.
- final = (
- _divide_with_ceil(self.count - self.offset, self.limit) +
- _divide_with_ceil(self.offset, self.limit)
- )
+ final = _divide_with_ceil(
+ self.count - self.offset, self.limit
+ ) + _divide_with_ceil(self.offset, self.limit)
if final < 1:
final = 1
@@ -424,9 +435,9 @@ class LimitOffsetPagination(BasePagination):
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
- 'previous_url': self.get_previous_link(),
- 'next_url': self.get_next_link(),
- 'page_links': page_links
+ "previous_url": self.get_previous_link(),
+ "next_url": self.get_next_link(),
+ "page_links": page_links,
}
def to_html(self):
@@ -435,27 +446,30 @@ class LimitOffsetPagination(BasePagination):
return template.render(context)
def get_schema_fields(self, view):
- assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
- assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
+ assert (
+ coreapi is not None
+ ), "coreapi must be installed to use `get_schema_fields()`"
+ assert (
+ coreschema is not None
+ ), "coreschema must be installed to use `get_schema_fields()`"
return [
coreapi.Field(
name=self.limit_query_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.Integer(
- title='Limit',
- description=force_text(self.limit_query_description)
- )
+ title="Limit", description=force_text(self.limit_query_description)
+ ),
),
coreapi.Field(
name=self.offset_query_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.Integer(
- title='Offset',
- description=force_text(self.offset_query_description)
- )
- )
+ title="Offset",
+ description=force_text(self.offset_query_description),
+ ),
+ ),
]
def get_count(self, queryset):
@@ -474,17 +488,18 @@ class CursorPagination(BasePagination):
For an overview of the position/offset style we use, see this post:
https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api
"""
- cursor_query_param = 'cursor'
- cursor_query_description = _('The pagination cursor value.')
+
+ cursor_query_param = "cursor"
+ cursor_query_description = _("The pagination cursor value.")
page_size = api_settings.PAGE_SIZE
- invalid_cursor_message = _('Invalid cursor')
- ordering = '-created'
- template = 'rest_framework/pagination/previous_and_next.html'
+ invalid_cursor_message = _("Invalid cursor")
+ ordering = "-created"
+ template = "rest_framework/pagination/previous_and_next.html"
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
- page_size_query_description = _('Number of results to return per page.')
+ page_size_query_description = _("Number of results to return per page.")
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
@@ -519,27 +534,29 @@ class CursorPagination(BasePagination):
# If we have a cursor with a fixed position then filter by that.
if current_position is not None:
order = self.ordering[0]
- is_reversed = order.startswith('-')
- order_attr = order.lstrip('-')
+ is_reversed = order.startswith("-")
+ order_attr = order.lstrip("-")
# Test for: (cursor reversed) XOR (queryset reversed)
if self.cursor.reverse != is_reversed:
- kwargs = {order_attr + '__lt': current_position}
+ kwargs = {order_attr + "__lt": current_position}
else:
- kwargs = {order_attr + '__gt': current_position}
+ kwargs = {order_attr + "__gt": current_position}
queryset = queryset.filter(**kwargs)
# If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a
# page following on from this one.
- results = list(queryset[offset:offset + self.page_size + 1])
- self.page = list(results[:self.page_size])
+ results = list(queryset[offset : offset + self.page_size + 1])
+ self.page = list(results[: self.page_size])
# Determine the position of the final item following the page.
if len(results) > len(self.page):
has_following_position = True
- following_position = self._get_position_from_instance(results[-1], self.ordering)
+ following_position = self._get_position_from_instance(
+ results[-1], self.ordering
+ )
else:
has_following_position = False
following_position = None
@@ -578,7 +595,7 @@ class CursorPagination(BasePagination):
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
- cutoff=self.max_page_size
+ cutoff=self.max_page_size,
)
except (KeyError, ValueError):
pass
@@ -686,8 +703,9 @@ class CursorPagination(BasePagination):
Return a tuple of strings, that may be used in an `order_by` method.
"""
ordering_filters = [
- filter_cls for filter_cls in getattr(view, 'filter_backends', [])
- if hasattr(filter_cls, 'get_ordering')
+ filter_cls
+ for filter_cls in getattr(view, "filter_backends", [])
+ if hasattr(filter_cls, "get_ordering")
]
if ordering_filters:
@@ -697,29 +715,27 @@ class CursorPagination(BasePagination):
filter_instance = filter_cls()
ordering = filter_instance.get_ordering(request, queryset, view)
assert ordering is not None, (
- 'Using cursor pagination, but filter class {filter_cls} '
- 'returned a `None` ordering.'.format(
- filter_cls=filter_cls.__name__
- )
+ "Using cursor pagination, but filter class {filter_cls} "
+ "returned a `None` ordering.".format(filter_cls=filter_cls.__name__)
)
else:
# The default case is to check for an `ordering` attribute
# on this pagination instance.
ordering = self.ordering
assert ordering is not None, (
- 'Using cursor pagination, but no ordering attribute was declared '
- 'on the pagination class.'
+ "Using cursor pagination, but no ordering attribute was declared "
+ "on the pagination class."
)
- assert '__' not in ordering, (
- 'Cursor pagination does not support double underscore lookups '
- 'for orderings. Orderings should be an unchanging, unique or '
+ assert "__" not in ordering, (
+ "Cursor pagination does not support double underscore lookups "
+ "for orderings. Orderings should be an unchanging, unique or "
'nearly-unique field on the model, such as "-created" or "pk".'
)
- assert isinstance(ordering, (six.string_types, list, tuple)), (
- 'Invalid ordering. Expected string or tuple, but got {type}'.format(
- type=type(ordering).__name__
- )
+ assert isinstance(
+ ordering, (six.string_types, list, tuple)
+ ), "Invalid ordering. Expected string or tuple, but got {type}".format(
+ type=type(ordering).__name__
)
if isinstance(ordering, six.string_types):
@@ -736,16 +752,16 @@ class CursorPagination(BasePagination):
return None
try:
- querystring = b64decode(encoded.encode('ascii')).decode('ascii')
+ querystring = b64decode(encoded.encode("ascii")).decode("ascii")
tokens = urlparse.parse_qs(querystring, keep_blank_values=True)
- offset = tokens.get('o', ['0'])[0]
+ offset = tokens.get("o", ["0"])[0]
offset = _positive_int(offset, cutoff=self.offset_cutoff)
- reverse = tokens.get('r', ['0'])[0]
+ reverse = tokens.get("r", ["0"])[0]
reverse = bool(int(reverse))
- position = tokens.get('p', [None])[0]
+ position = tokens.get("p", [None])[0]
except (TypeError, ValueError):
raise NotFound(self.invalid_cursor_message)
@@ -757,18 +773,18 @@ class CursorPagination(BasePagination):
"""
tokens = {}
if cursor.offset != 0:
- tokens['o'] = str(cursor.offset)
+ tokens["o"] = str(cursor.offset)
if cursor.reverse:
- tokens['r'] = '1'
+ tokens["r"] = "1"
if cursor.position is not None:
- tokens['p'] = cursor.position
+ tokens["p"] = cursor.position
querystring = urlparse.urlencode(tokens, doseq=True)
- encoded = b64encode(querystring.encode('ascii')).decode('ascii')
+ encoded = b64encode(querystring.encode("ascii")).decode("ascii")
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
def _get_position_from_instance(self, instance, ordering):
- field_name = ordering[0].lstrip('-')
+ field_name = ordering[0].lstrip("-")
if isinstance(instance, dict):
attr = instance[field_name]
else:
@@ -776,16 +792,20 @@ class CursorPagination(BasePagination):
return six.text_type(attr)
def get_paginated_response(self, data):
- return Response(OrderedDict([
- ('next', self.get_next_link()),
- ('previous', self.get_previous_link()),
- ('results', data)
- ]))
+ return Response(
+ OrderedDict(
+ [
+ ("next", self.get_next_link()),
+ ("previous", self.get_previous_link()),
+ ("results", data),
+ ]
+ )
+ )
def get_html_context(self):
return {
- 'previous_url': self.get_previous_link(),
- 'next_url': self.get_next_link()
+ "previous_url": self.get_previous_link(),
+ "next_url": self.get_next_link(),
}
def to_html(self):
@@ -794,17 +814,21 @@ class CursorPagination(BasePagination):
return template.render(context)
def get_schema_fields(self, view):
- assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
- assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
+ assert (
+ coreapi is not None
+ ), "coreapi must be installed to use `get_schema_fields()`"
+ assert (
+ coreschema is not None
+ ), "coreschema must be installed to use `get_schema_fields()`"
fields = [
coreapi.Field(
name=self.cursor_query_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.String(
- title='Cursor',
- description=force_text(self.cursor_query_description)
- )
+ title="Cursor",
+ description=force_text(self.cursor_query_description),
+ ),
)
]
if self.page_size_query_param is not None:
@@ -812,11 +836,11 @@ class CursorPagination(BasePagination):
coreapi.Field(
name=self.page_size_query_param,
required=False,
- location='query',
+ location="query",
schema=coreschema.Integer(
- title='Page size',
- description=force_text(self.page_size_query_description)
- )
+ title="Page size",
+ description=force_text(self.page_size_query_description),
+ ),
)
)
return fields
diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py
index 35d0d1aa7..22307ae2a 100644
--- a/rest_framework/parsers.py
+++ b/rest_framework/parsers.py
@@ -11,10 +11,12 @@ import codecs
from django.conf import settings
from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
-from django.http.multipartparser import ChunkIter
-from django.http.multipartparser import \
- MultiPartParser as DjangoMultiPartParser
-from django.http.multipartparser import MultiPartParserError, parse_header
+from django.http.multipartparser import (
+ ChunkIter,
+ MultiPartParser as DjangoMultiPartParser,
+ MultiPartParserError,
+ parse_header,
+)
from django.utils import six
from django.utils.encoding import force_text
from django.utils.six.moves.urllib import parse as urlparse
@@ -36,6 +38,7 @@ class BaseParser(object):
All parsers should extend `BaseParser`, specifying a `media_type`
attribute, and overriding the `.parse()` method.
"""
+
media_type = None
def parse(self, stream, media_type=None, parser_context=None):
@@ -51,7 +54,8 @@ class JSONParser(BaseParser):
"""
Parses JSON-serialized data.
"""
- media_type = 'application/json'
+
+ media_type = "application/json"
renderer_class = renderers.JSONRenderer
strict = api_settings.STRICT_JSON
@@ -60,21 +64,22 @@ class JSONParser(BaseParser):
Parses the incoming bytestream as JSON and returns the resulting data.
"""
parser_context = parser_context or {}
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
try:
decoded_stream = codecs.getreader(encoding)(stream)
parse_constant = json.strict_constant if self.strict else None
return json.load(decoded_stream, parse_constant=parse_constant)
except ValueError as exc:
- raise ParseError('JSON parse error - %s' % six.text_type(exc))
+ raise ParseError("JSON parse error - %s" % six.text_type(exc))
class FormParser(BaseParser):
"""
Parser for form data.
"""
- media_type = 'application/x-www-form-urlencoded'
+
+ media_type = "application/x-www-form-urlencoded"
def parse(self, stream, media_type=None, parser_context=None):
"""
@@ -82,7 +87,7 @@ class FormParser(BaseParser):
and returns the resulting QueryDict.
"""
parser_context = parser_context or {}
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
data = QueryDict(stream.read(), encoding=encoding)
return data
@@ -91,7 +96,8 @@ class MultiPartParser(BaseParser):
"""
Parser for multipart form data, which may include file data.
"""
- media_type = 'multipart/form-data'
+
+ media_type = "multipart/form-data"
def parse(self, stream, media_type=None, parser_context=None):
"""
@@ -102,10 +108,10 @@ class MultiPartParser(BaseParser):
`.files` will be a `QueryDict` containing all the form files.
"""
parser_context = parser_context or {}
- request = parser_context['request']
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ request = parser_context["request"]
+ encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
meta = request.META.copy()
- meta['CONTENT_TYPE'] = media_type
+ meta["CONTENT_TYPE"] = media_type
upload_handlers = request.upload_handlers
try:
@@ -113,17 +119,18 @@ class MultiPartParser(BaseParser):
data, files = parser.parse()
return DataAndFiles(data, files)
except MultiPartParserError as exc:
- raise ParseError('Multipart form parse error - %s' % six.text_type(exc))
+ raise ParseError("Multipart form parse error - %s" % six.text_type(exc))
class FileUploadParser(BaseParser):
"""
Parser for file upload data.
"""
- media_type = '*/*'
+
+ media_type = "*/*"
errors = {
- 'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
- 'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
+ "unhandled": "FileUpload parse error - none of upload handlers can handle the stream",
+ "no_filename": "Missing filename. Request should include a Content-Disposition header with a filename parameter.",
}
def parse(self, stream, media_type=None, parser_context=None):
@@ -135,34 +142,32 @@ class FileUploadParser(BaseParser):
`.files` will be a `QueryDict` containing one 'file' element.
"""
parser_context = parser_context or {}
- request = parser_context['request']
- encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
+ request = parser_context["request"]
+ encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
meta = request.META
upload_handlers = request.upload_handlers
filename = self.get_filename(stream, media_type, parser_context)
if not filename:
- raise ParseError(self.errors['no_filename'])
+ raise ParseError(self.errors["no_filename"])
# Note that this code is extracted from Django's handling of
# file uploads in MultiPartParser.
- content_type = meta.get('HTTP_CONTENT_TYPE',
- meta.get('CONTENT_TYPE', ''))
+ content_type = meta.get("HTTP_CONTENT_TYPE", meta.get("CONTENT_TYPE", ""))
try:
- content_length = int(meta.get('HTTP_CONTENT_LENGTH',
- meta.get('CONTENT_LENGTH', 0)))
+ content_length = int(
+ meta.get("HTTP_CONTENT_LENGTH", meta.get("CONTENT_LENGTH", 0))
+ )
except (ValueError, TypeError):
content_length = None
# See if the handler will want to take care of the parsing.
for handler in upload_handlers:
- result = handler.handle_raw_input(stream,
- meta,
- content_length,
- None,
- encoding)
+ result = handler.handle_raw_input(
+ stream, meta, content_length, None, encoding
+ )
if result is not None:
- return DataAndFiles({}, {'file': result[1]})
+ return DataAndFiles({}, {"file": result[1]})
# This is the standard case.
possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
@@ -172,10 +177,9 @@ class FileUploadParser(BaseParser):
for index, handler in enumerate(upload_handlers):
try:
- handler.new_file(None, filename, content_type,
- content_length, encoding)
+ handler.new_file(None, filename, content_type, content_length, encoding)
except StopFutureHandlers:
- upload_handlers = upload_handlers[:index + 1]
+ upload_handlers = upload_handlers[: index + 1]
break
for chunk in chunks:
@@ -189,9 +193,9 @@ class FileUploadParser(BaseParser):
for index, handler in enumerate(upload_handlers):
file_obj = handler.file_complete(counters[index])
if file_obj is not None:
- return DataAndFiles({}, {'file': file_obj})
+ return DataAndFiles({}, {"file": file_obj})
- raise ParseError(self.errors['unhandled'])
+ raise ParseError(self.errors["unhandled"])
def get_filename(self, stream, media_type, parser_context):
"""
@@ -199,17 +203,17 @@ class FileUploadParser(BaseParser):
Then tries to parse Content-Disposition header.
"""
try:
- return parser_context['kwargs']['filename']
+ return parser_context["kwargs"]["filename"]
except KeyError:
pass
try:
- meta = parser_context['request'].META
- disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8'))
+ meta = parser_context["request"].META
+ disposition = parse_header(meta["HTTP_CONTENT_DISPOSITION"].encode("utf-8"))
filename_parm = disposition[1]
- if 'filename*' in filename_parm:
+ if "filename*" in filename_parm:
return self.get_encoded_filename(filename_parm)
- return force_text(filename_parm['filename'])
+ return force_text(filename_parm["filename"])
except (AttributeError, KeyError, ValueError):
pass
@@ -218,10 +222,10 @@ class FileUploadParser(BaseParser):
Handle encoded filenames per RFC6266. See also:
https://tools.ietf.org/html/rfc2231#section-4
"""
- encoded_filename = force_text(filename_parm['filename*'])
+ encoded_filename = force_text(filename_parm["filename*"])
try:
- charset, lang, filename = encoded_filename.split('\'', 2)
+ charset, lang, filename = encoded_filename.split("'", 2)
filename = urlparse.unquote(filename)
except (ValueError, LookupError):
- filename = force_text(filename_parm['filename'])
+ filename = force_text(filename_parm["filename"])
return filename
diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py
index 5d75f54ba..aff42caab 100644
--- a/rest_framework/permissions.py
+++ b/rest_framework/permissions.py
@@ -8,7 +8,8 @@ from django.utils import six
from rest_framework import exceptions
-SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
+
+SAFE_METHODS = ("GET", "HEAD", "OPTIONS")
class OperationHolderMixin:
@@ -56,16 +57,14 @@ class AND:
self.op2 = op2
def has_permission(self, request, view):
- return (
- self.op1.has_permission(request, view) and
- self.op2.has_permission(request, view)
+ return self.op1.has_permission(request, view) and self.op2.has_permission(
+ request, view
)
def has_object_permission(self, request, view, obj):
- return (
- self.op1.has_object_permission(request, view, obj) and
- self.op2.has_object_permission(request, view, obj)
- )
+ return self.op1.has_object_permission(
+ request, view, obj
+ ) and self.op2.has_object_permission(request, view, obj)
class OR:
@@ -74,16 +73,14 @@ class OR:
self.op2 = op2
def has_permission(self, request, view):
- return (
- self.op1.has_permission(request, view) or
- self.op2.has_permission(request, view)
+ return self.op1.has_permission(request, view) or self.op2.has_permission(
+ request, view
)
def has_object_permission(self, request, view, obj):
- return (
- self.op1.has_object_permission(request, view, obj) or
- self.op2.has_object_permission(request, view, obj)
- )
+ return self.op1.has_object_permission(
+ request, view, obj
+ ) or self.op2.has_object_permission(request, view, obj)
class NOT:
@@ -157,9 +154,9 @@ class IsAuthenticatedOrReadOnly(BasePermission):
def has_permission(self, request, view):
return bool(
- request.method in SAFE_METHODS or
- request.user and
- request.user.is_authenticated
+ request.method in SAFE_METHODS
+ or request.user
+ and request.user.is_authenticated
)
@@ -179,13 +176,13 @@ class DjangoModelPermissions(BasePermission):
# Override this if you need to also provide 'view' permissions,
# or if you want to provide custom permission codes.
perms_map = {
- 'GET': [],
- 'OPTIONS': [],
- 'HEAD': [],
- 'POST': ['%(app_label)s.add_%(model_name)s'],
- 'PUT': ['%(app_label)s.change_%(model_name)s'],
- 'PATCH': ['%(app_label)s.change_%(model_name)s'],
- 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
+ "GET": [],
+ "OPTIONS": [],
+ "HEAD": [],
+ "POST": ["%(app_label)s.add_%(model_name)s"],
+ "PUT": ["%(app_label)s.change_%(model_name)s"],
+ "PATCH": ["%(app_label)s.change_%(model_name)s"],
+ "DELETE": ["%(app_label)s.delete_%(model_name)s"],
}
authenticated_users_only = True
@@ -196,8 +193,8 @@ class DjangoModelPermissions(BasePermission):
codes that the user is required to have.
"""
kwargs = {
- 'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.model_name
+ "app_label": model_cls._meta.app_label,
+ "model_name": model_cls._meta.model_name,
}
if method not in self.perms_map:
@@ -206,16 +203,19 @@ class DjangoModelPermissions(BasePermission):
return [perm % kwargs for perm in self.perms_map[method]]
def _queryset(self, view):
- assert hasattr(view, 'get_queryset') \
- or getattr(view, 'queryset', None) is not None, (
- 'Cannot apply {} on a view that does not set '
- '`.queryset` or have a `.get_queryset()` method.'
- ).format(self.__class__.__name__)
+ assert (
+ hasattr(view, "get_queryset") or getattr(view, "queryset", None) is not None
+ ), (
+ "Cannot apply {} on a view that does not set "
+ "`.queryset` or have a `.get_queryset()` method."
+ ).format(
+ self.__class__.__name__
+ )
- if hasattr(view, 'get_queryset'):
+ if hasattr(view, "get_queryset"):
queryset = view.get_queryset()
- assert queryset is not None, (
- '{}.get_queryset() returned None'.format(view.__class__.__name__)
+ assert queryset is not None, "{}.get_queryset() returned None".format(
+ view.__class__.__name__
)
return queryset
return view.queryset
@@ -223,11 +223,12 @@ class DjangoModelPermissions(BasePermission):
def has_permission(self, request, view):
# Workaround to ensure DjangoModelPermissions are not applied
# to the root view when using DefaultRouter.
- if getattr(view, '_ignore_model_permissions', False):
+ if getattr(view, "_ignore_model_permissions", False):
return True
if not request.user or (
- not request.user.is_authenticated and self.authenticated_users_only):
+ not request.user.is_authenticated and self.authenticated_users_only
+ ):
return False
queryset = self._queryset(view)
@@ -241,6 +242,7 @@ class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
Similar to DjangoModelPermissions, except that anonymous users are
allowed read-only access.
"""
+
authenticated_users_only = False
@@ -255,20 +257,21 @@ class DjangoObjectPermissions(DjangoModelPermissions):
This permission can only be applied against view classes that
provide a `.queryset` attribute.
"""
+
perms_map = {
- 'GET': [],
- 'OPTIONS': [],
- 'HEAD': [],
- 'POST': ['%(app_label)s.add_%(model_name)s'],
- 'PUT': ['%(app_label)s.change_%(model_name)s'],
- 'PATCH': ['%(app_label)s.change_%(model_name)s'],
- 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
+ "GET": [],
+ "OPTIONS": [],
+ "HEAD": [],
+ "POST": ["%(app_label)s.add_%(model_name)s"],
+ "PUT": ["%(app_label)s.change_%(model_name)s"],
+ "PATCH": ["%(app_label)s.change_%(model_name)s"],
+ "DELETE": ["%(app_label)s.delete_%(model_name)s"],
}
def get_required_object_permissions(self, method, model_cls):
kwargs = {
- 'app_label': model_cls._meta.app_label,
- 'model_name': model_cls._meta.model_name
+ "app_label": model_cls._meta.app_label,
+ "model_name": model_cls._meta.model_name,
}
if method not in self.perms_map:
@@ -294,7 +297,7 @@ class DjangoObjectPermissions(DjangoModelPermissions):
# to make another lookup.
raise Http404
- read_perms = self.get_required_object_permissions('GET', model_cls)
+ read_perms = self.get_required_object_permissions("GET", model_cls)
if not user.has_perms(read_perms, obj):
raise Http404
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 31c1e7561..ef2dd53f2 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -9,14 +9,16 @@ from django.db.models import Manager
from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils import six
-from django.utils.encoding import (
- python_2_unicode_compatible, smart_text, uri_to_iri
-)
+from django.utils.encoding import python_2_unicode_compatible, smart_text, uri_to_iri
from django.utils.six.moves.urllib import parse as urlparse
from django.utils.translation import ugettext_lazy as _
from rest_framework.fields import (
- Field, empty, get_attribute, is_simple_callable, iter_options
+ Field,
+ empty,
+ get_attribute,
+ is_simple_callable,
+ iter_options,
)
from rest_framework.reverse import reverse
from rest_framework.settings import api_settings
@@ -28,7 +30,7 @@ def method_overridden(method_name, klass, instance):
Determine if a method has been overridden.
"""
method = getattr(klass, method_name)
- default_method = getattr(method, '__func__', method) # Python 3 compat
+ default_method = getattr(method, "__func__", method) # Python 3 compat
return default_method is not getattr(instance, method_name).__func__
@@ -52,13 +54,14 @@ class Hyperlink(six.text_type):
We use this for hyperlinked URLs that may render as a named link
in some contexts, or render as a plain URL in others.
"""
+
def __new__(self, url, obj):
ret = six.text_type.__new__(self, url)
ret.obj = obj
return ret
def __getnewargs__(self):
- return(str(self), self.name,)
+ return (str(self), self.name)
@property
def name(self):
@@ -77,6 +80,7 @@ class PKOnlyObject(object):
instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance.
"""
+
def __init__(self, pk):
self.pk = pk
@@ -87,9 +91,19 @@ class PKOnlyObject(object):
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
MANY_RELATION_KWARGS = (
- 'read_only', 'write_only', 'required', 'default', 'initial', 'source',
- 'label', 'help_text', 'style', 'error_messages', 'allow_empty',
- 'html_cutoff', 'html_cutoff_text'
+ "read_only",
+ "write_only",
+ "required",
+ "default",
+ "initial",
+ "source",
+ "label",
+ "help_text",
+ "style",
+ "error_messages",
+ "allow_empty",
+ "html_cutoff",
+ "html_cutoff_text",
)
@@ -99,34 +113,34 @@ class RelatedField(Field):
html_cutoff_text = None
def __init__(self, **kwargs):
- self.queryset = kwargs.pop('queryset', self.queryset)
+ self.queryset = kwargs.pop("queryset", self.queryset)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
- self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
+ self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
- 'html_cutoff_text',
- self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
+ "html_cutoff_text",
+ self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT),
)
- if not method_overridden('get_queryset', RelatedField, self):
- assert self.queryset is not None or kwargs.get('read_only', None), (
- 'Relational field must provide a `queryset` argument, '
- 'override `get_queryset`, or set read_only=`True`.'
+ if not method_overridden("get_queryset", RelatedField, self):
+ assert self.queryset is not None or kwargs.get("read_only", None), (
+ "Relational field must provide a `queryset` argument, "
+ "override `get_queryset`, or set read_only=`True`."
)
- assert not (self.queryset is not None and kwargs.get('read_only', None)), (
- 'Relational fields should not provide a `queryset` argument, '
- 'when setting read_only=`True`.'
+ assert not (self.queryset is not None and kwargs.get("read_only", None)), (
+ "Relational fields should not provide a `queryset` argument, "
+ "when setting read_only=`True`."
)
- kwargs.pop('many', None)
- kwargs.pop('allow_empty', None)
+ kwargs.pop("many", None)
+ kwargs.pop("allow_empty", None)
super(RelatedField, self).__init__(**kwargs)
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
# `ManyRelatedField` classes instead when `many=True` is set.
- if kwargs.pop('many', False):
+ if kwargs.pop("many", False):
return cls.many_init(*args, **kwargs)
return super(RelatedField, cls).__new__(cls, *args, **kwargs)
@@ -147,7 +161,7 @@ class RelatedField(Field):
kwargs['child'] = cls()
return CustomManyRelatedField(*args, **kwargs)
"""
- list_kwargs = {'child_relation': cls(*args, **kwargs)}
+ list_kwargs = {"child_relation": cls(*args, **kwargs)}
for key in kwargs:
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
@@ -155,7 +169,7 @@ class RelatedField(Field):
def run_validation(self, data=empty):
# We force empty strings to None values for relational fields.
- if data == '':
+ if data == "":
data = None
return super(RelatedField, self).run_validation(data)
@@ -201,13 +215,12 @@ class RelatedField(Field):
if cutoff is not None:
queryset = queryset[:cutoff]
- return OrderedDict([
- (
- self.to_representation(item),
- self.display_value(item)
- )
- for item in queryset
- ])
+ return OrderedDict(
+ [
+ (self.to_representation(item), self.display_value(item))
+ for item in queryset
+ ]
+ )
@property
def choices(self):
@@ -221,7 +234,7 @@ class RelatedField(Field):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
- cutoff_text=self.html_cutoff_text
+ cutoff_text=self.html_cutoff_text,
)
def display_value(self, instance):
@@ -235,7 +248,7 @@ class StringRelatedField(RelatedField):
"""
def __init__(self, **kwargs):
- kwargs['read_only'] = True
+ kwargs["read_only"] = True
super(StringRelatedField, self).__init__(**kwargs)
def to_representation(self, value):
@@ -244,13 +257,13 @@ class StringRelatedField(RelatedField):
class PrimaryKeyRelatedField(RelatedField):
default_error_messages = {
- 'required': _('This field is required.'),
- 'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
- 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
+ "required": _("This field is required."),
+ "does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'),
+ "incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."),
}
def __init__(self, **kwargs):
- self.pk_field = kwargs.pop('pk_field', None)
+ self.pk_field = kwargs.pop("pk_field", None)
super(PrimaryKeyRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self):
@@ -262,9 +275,9 @@ class PrimaryKeyRelatedField(RelatedField):
try:
return self.get_queryset().get(pk=data)
except ObjectDoesNotExist:
- self.fail('does_not_exist', pk_value=data)
+ self.fail("does_not_exist", pk_value=data)
except (TypeError, ValueError):
- self.fail('incorrect_type', data_type=type(data).__name__)
+ self.fail("incorrect_type", data_type=type(data).__name__)
def to_representation(self, value):
if self.pk_field is not None:
@@ -273,24 +286,26 @@ class PrimaryKeyRelatedField(RelatedField):
class HyperlinkedRelatedField(RelatedField):
- lookup_field = 'pk'
+ lookup_field = "pk"
view_name = None
default_error_messages = {
- 'required': _('This field is required.'),
- 'no_match': _('Invalid hyperlink - No URL match.'),
- 'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'),
- 'does_not_exist': _('Invalid hyperlink - Object does not exist.'),
- 'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'),
+ "required": _("This field is required."),
+ "no_match": _("Invalid hyperlink - No URL match."),
+ "incorrect_match": _("Invalid hyperlink - Incorrect URL match."),
+ "does_not_exist": _("Invalid hyperlink - Object does not exist."),
+ "incorrect_type": _(
+ "Incorrect type. Expected URL string, received {data_type}."
+ ),
}
def __init__(self, view_name=None, **kwargs):
if view_name is not None:
self.view_name = view_name
- assert self.view_name is not None, 'The `view_name` argument is required.'
- self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
- self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
- self.format = kwargs.pop('format', None)
+ assert self.view_name is not None, "The `view_name` argument is required."
+ self.lookup_field = kwargs.pop("lookup_field", self.lookup_field)
+ self.lookup_url_kwarg = kwargs.pop("lookup_url_kwarg", self.lookup_field)
+ self.format = kwargs.pop("format", None)
# We include this simply for dependency injection in tests.
# We can't add it as a class attributes or it would expect an
@@ -300,7 +315,7 @@ class HyperlinkedRelatedField(RelatedField):
super(HyperlinkedRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self):
- return self.lookup_field == 'pk'
+ return self.lookup_field == "pk"
def get_object(self, view_name, view_args, view_kwargs):
"""
@@ -330,7 +345,7 @@ class HyperlinkedRelatedField(RelatedField):
attributes are not configured to correctly match the URL conf.
"""
# Unsaved objects will not yet have a valid URL.
- if hasattr(obj, 'pk') and obj.pk in (None, ''):
+ if hasattr(obj, "pk") and obj.pk in (None, ""):
return None
lookup_value = getattr(obj, self.lookup_field)
@@ -338,25 +353,25 @@ class HyperlinkedRelatedField(RelatedField):
return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
def to_internal_value(self, data):
- request = self.context.get('request', None)
+ request = self.context.get("request", None)
try:
- http_prefix = data.startswith(('http:', 'https:'))
+ http_prefix = data.startswith(("http:", "https:"))
except AttributeError:
- self.fail('incorrect_type', data_type=type(data).__name__)
+ self.fail("incorrect_type", data_type=type(data).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
data = urlparse.urlparse(data).path
prefix = get_script_prefix()
if data.startswith(prefix):
- data = '/' + data[len(prefix):]
+ data = "/" + data[len(prefix) :]
data = uri_to_iri(data)
try:
match = resolve(data)
except Resolver404:
- self.fail('no_match')
+ self.fail("no_match")
try:
expected_viewname = request.versioning_scheme.get_versioned_viewname(
@@ -366,22 +381,22 @@ class HyperlinkedRelatedField(RelatedField):
expected_viewname = self.view_name
if match.view_name != expected_viewname:
- self.fail('incorrect_match')
+ self.fail("incorrect_match")
try:
return self.get_object(match.view_name, match.args, match.kwargs)
except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError):
- self.fail('does_not_exist')
+ self.fail("does_not_exist")
def to_representation(self, value):
- assert 'request' in self.context, (
+ assert "request" in self.context, (
"`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating "
"the serializer." % self.__class__.__name__
)
- request = self.context['request']
- format = self.context.get('format', None)
+ request = self.context["request"]
+ format = self.context.get("format", None)
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
@@ -400,13 +415,13 @@ class HyperlinkedRelatedField(RelatedField):
url = self.get_url(value, self.view_name, request, format)
except NoReverseMatch:
msg = (
- 'Could not resolve URL for hyperlinked relationship using '
+ "Could not resolve URL for hyperlinked relationship using "
'view name "%s". You may have failed to include the related '
- 'model in your API, or incorrectly configured the '
- '`lookup_field` attribute on this field.'
+ "model in your API, or incorrectly configured the "
+ "`lookup_field` attribute on this field."
)
- if value in ('', None):
- value_string = {'': 'the empty string', None: 'None'}[value]
+ if value in ("", None):
+ value_string = {"": "the empty string", None: "None"}[value]
msg += (
" WARNING: The value of the field on the model instance "
"was %s, which may be why it didn't match any "
@@ -429,9 +444,9 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
"""
def __init__(self, view_name=None, **kwargs):
- assert view_name is not None, 'The `view_name` argument is required.'
- kwargs['read_only'] = True
- kwargs['source'] = '*'
+ assert view_name is not None, "The `view_name` argument is required."
+ kwargs["read_only"] = True
+ kwargs["source"] = "*"
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
def use_pk_only_optimization(self):
@@ -445,13 +460,14 @@ class SlugRelatedField(RelatedField):
A read-write field that represents the target of the relationship
by a unique 'slug' attribute.
"""
+
default_error_messages = {
- 'does_not_exist': _('Object with {slug_name}={value} does not exist.'),
- 'invalid': _('Invalid value.'),
+ "does_not_exist": _("Object with {slug_name}={value} does not exist."),
+ "invalid": _("Invalid value."),
}
def __init__(self, slug_field=None, **kwargs):
- assert slug_field is not None, 'The `slug_field` argument is required.'
+ assert slug_field is not None, "The `slug_field` argument is required."
self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs)
@@ -459,9 +475,11 @@ class SlugRelatedField(RelatedField):
try:
return self.get_queryset().get(**{self.slug_field: data})
except ObjectDoesNotExist:
- self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data))
+ self.fail(
+ "does_not_exist", slug_name=self.slug_field, value=smart_text(data)
+ )
except (TypeError, ValueError):
- self.fail('invalid')
+ self.fail("invalid")
def to_representation(self, obj):
return getattr(obj, self.slug_field)
@@ -479,31 +497,32 @@ class ManyRelatedField(Field):
You shouldn't generally need to be using this class directly yourself,
and should instead simply set 'many=True' on the relationship.
"""
+
initial = []
default_empty_html = []
default_error_messages = {
- 'not_a_list': _('Expected a list of items but got type "{input_type}".'),
- 'empty': _('This list may not be empty.')
+ "not_a_list": _('Expected a list of items but got type "{input_type}".'),
+ "empty": _("This list may not be empty."),
}
html_cutoff = None
html_cutoff_text = None
def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
- self.allow_empty = kwargs.pop('allow_empty', True)
+ self.allow_empty = kwargs.pop("allow_empty", True)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
- self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
+ self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
- 'html_cutoff_text',
- self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
+ "html_cutoff_text",
+ self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT),
)
- assert child_relation is not None, '`child_relation` is a required argument.'
+ assert child_relation is not None, "`child_relation` is a required argument."
super(ManyRelatedField, self).__init__(*args, **kwargs)
- self.child_relation.bind(field_name='', parent=self)
+ self.child_relation.bind(field_name="", parent=self)
def get_value(self, dictionary):
# We override the default field access in order to support
@@ -511,36 +530,30 @@ class ManyRelatedField(Field):
if html.is_html_input(dictionary):
# Don't return [] if the update is partial
if self.field_name not in dictionary:
- if getattr(self.root, 'partial', False):
+ if getattr(self.root, "partial", False):
return empty
return dictionary.getlist(self.field_name)
return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
- if isinstance(data, six.text_type) or not hasattr(data, '__iter__'):
- self.fail('not_a_list', input_type=type(data).__name__)
+ if isinstance(data, six.text_type) or not hasattr(data, "__iter__"):
+ self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
- self.fail('empty')
+ self.fail("empty")
- return [
- self.child_relation.to_internal_value(item)
- for item in data
- ]
+ return [self.child_relation.to_internal_value(item) for item in data]
def get_attribute(self, instance):
# Can't have any relationships if not created
- if hasattr(instance, 'pk') and instance.pk is None:
+ if hasattr(instance, "pk") and instance.pk is None:
return []
relationship = get_attribute(instance, self.source_attrs)
- return relationship.all() if hasattr(relationship, 'all') else relationship
+ return relationship.all() if hasattr(relationship, "all") else relationship
def to_representation(self, iterable):
- return [
- self.child_relation.to_representation(value)
- for value in iterable
- ]
+ return [self.child_relation.to_representation(value) for value in iterable]
def get_choices(self, cutoff=None):
return self.child_relation.get_choices(cutoff)
@@ -557,5 +570,5 @@ class ManyRelatedField(Field):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
- cutoff_text=self.html_cutoff_text
+ cutoff_text=self.html_cutoff_text,
)
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index f043e6327..176c2b2d2 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -25,8 +25,13 @@ from django.utils.six.moves.urllib import parse as urlparse
from rest_framework import VERSION, exceptions, serializers, status
from rest_framework.compat import (
- INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema,
- pygments_css, yaml
+ INDENT_SEPARATORS,
+ LONG_SEPARATORS,
+ SHORT_SEPARATORS,
+ coreapi,
+ coreschema,
+ pygments_css,
+ yaml,
)
from rest_framework.exceptions import ParseError
from rest_framework.request import is_form_media_type, override_method
@@ -45,21 +50,23 @@ class BaseRenderer(object):
All renderers should extend this class, setting the `media_type`
and `format` attributes, and override the `.render()` method.
"""
+
media_type = None
format = None
- charset = 'utf-8'
- render_style = 'text'
+ charset = "utf-8"
+ render_style = "text"
def render(self, data, accepted_media_type=None, renderer_context=None):
- raise NotImplementedError('Renderer class requires .render() to be implemented')
+ raise NotImplementedError("Renderer class requires .render() to be implemented")
class JSONRenderer(BaseRenderer):
"""
Renderer which serializes to JSON.
"""
- media_type = 'application/json'
- format = 'json'
+
+ media_type = "application/json"
+ format = "json"
encoder_class = encoders.JSONEncoder
ensure_ascii = not api_settings.UNICODE_JSON
compact = api_settings.COMPACT_JSON
@@ -76,15 +83,15 @@ class JSONRenderer(BaseRenderer):
# If the media type looks like 'application/json; indent=4',
# then pretty print the result.
# Note that we coerce `indent=0` into `indent=None`.
- base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
+ base_media_type, params = parse_header(accepted_media_type.encode("ascii"))
try:
- return zero_as_none(max(min(int(params['indent']), 8), 0))
+ return zero_as_none(max(min(int(params["indent"]), 8), 0))
except (KeyError, ValueError, TypeError):
pass
# If 'indent' is provided in the context, then pretty print the result.
# E.g. If we're being called by the BrowsableAPIRenderer.
- return renderer_context.get('indent', None)
+ return renderer_context.get("indent", None)
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -102,9 +109,12 @@ class JSONRenderer(BaseRenderer):
separators = INDENT_SEPARATORS
ret = json.dumps(
- data, cls=self.encoder_class,
- indent=indent, ensure_ascii=self.ensure_ascii,
- allow_nan=not self.strict, separators=separators
+ data,
+ cls=self.encoder_class,
+ indent=indent,
+ ensure_ascii=self.ensure_ascii,
+ allow_nan=not self.strict,
+ separators=separators,
)
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
@@ -116,8 +126,8 @@ class JSONRenderer(BaseRenderer):
# that is a strict javascript subset. If bytes were returned
# by json.dumps() then we don't have these characters in any case.
# See: http://timelessrepo.com/json-isnt-a-javascript-subset
- ret = ret.replace('\u2028', '\\u2028').replace('\u2029', '\\u2029')
- return bytes(ret.encode('utf-8'))
+ ret = ret.replace("\u2028", "\\u2028").replace("\u2029", "\\u2029")
+ return bytes(ret.encode("utf-8"))
return ret
@@ -140,14 +150,12 @@ class TemplateHTMLRenderer(BaseRenderer):
For pre-rendered HTML, see StaticHTMLRenderer.
"""
- media_type = 'text/html'
- format = 'html'
+
+ media_type = "text/html"
+ format = "html"
template_name = None
- exception_template_names = [
- '%(status_code)s.html',
- 'api_exception.html'
- ]
- charset = 'utf-8'
+ exception_template_names = ["%(status_code)s.html", "api_exception.html"]
+ charset = "utf-8"
def render(self, data, accepted_media_type=None, renderer_context=None):
"""
@@ -160,9 +168,9 @@ class TemplateHTMLRenderer(BaseRenderer):
3. The return result of calling view.get_template_names().
"""
renderer_context = renderer_context or {}
- view = renderer_context['view']
- request = renderer_context['request']
- response = renderer_context['response']
+ view = renderer_context["view"]
+ request = renderer_context["request"]
+ response = renderer_context["response"]
if response.exception:
template = self.get_exception_template(response)
@@ -170,7 +178,7 @@ class TemplateHTMLRenderer(BaseRenderer):
template_names = self.get_template_names(response, view)
template = self.resolve_template(template_names)
- if hasattr(self, 'resolve_context'):
+ if hasattr(self, "resolve_context"):
# Fallback for older versions.
context = self.resolve_context(data, request, response)
else:
@@ -181,9 +189,9 @@ class TemplateHTMLRenderer(BaseRenderer):
return loader.select_template(template_names)
def get_template_context(self, data, renderer_context):
- response = renderer_context['response']
+ response = renderer_context["response"]
if response.exception:
- data['status_code'] = response.status_code
+ data["status_code"] = response.status_code
return data
def get_template_names(self, response, view):
@@ -191,25 +199,27 @@ class TemplateHTMLRenderer(BaseRenderer):
return [response.template_name]
elif self.template_name:
return [self.template_name]
- elif hasattr(view, 'get_template_names'):
+ elif hasattr(view, "get_template_names"):
return view.get_template_names()
- elif hasattr(view, 'template_name'):
+ elif hasattr(view, "template_name"):
return [view.template_name]
raise ImproperlyConfigured(
- 'Returned a template response with no `template_name` attribute set on either the view or response'
+ "Returned a template response with no `template_name` attribute set on either the view or response"
)
def get_exception_template(self, response):
- template_names = [name % {'status_code': response.status_code}
- for name in self.exception_template_names]
+ template_names = [
+ name % {"status_code": response.status_code}
+ for name in self.exception_template_names
+ ]
try:
# Try to find an appropriate error template
return self.resolve_template(template_names)
except Exception:
# Fall back to using eg '404 Not Found'
- body = '%d %s' % (response.status_code, response.status_text.title())
- template = engines['django'].from_string(body)
+ body = "%d %s" % (response.status_code, response.status_text.title())
+ template = engines["django"].from_string(body)
return template
@@ -227,18 +237,19 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):
For template rendered HTML, see TemplateHTMLRenderer.
"""
- media_type = 'text/html'
- format = 'html'
- charset = 'utf-8'
+
+ media_type = "text/html"
+ format = "html"
+ charset = "utf-8"
def render(self, data, accepted_media_type=None, renderer_context=None):
renderer_context = renderer_context or {}
- response = renderer_context.get('response')
+ response = renderer_context.get("response")
if response and response.exception:
- request = renderer_context['request']
+ request = renderer_context["request"]
template = self.get_exception_template(response)
- if hasattr(self, 'resolve_context'):
+ if hasattr(self, "resolve_context"):
context = self.resolve_context(data, request, response)
else:
context = self.get_template_context(data, renderer_context)
@@ -258,107 +269,96 @@ class HTMLFormRenderer(BaseRenderer):
Note that rendering of field and form errors is not currently supported.
"""
- media_type = 'text/html'
- format = 'form'
- charset = 'utf-8'
- template_pack = 'rest_framework/vertical/'
- base_template = 'form.html'
- default_style = ClassLookupDict({
- serializers.Field: {
- 'base_template': 'input.html',
- 'input_type': 'text'
- },
- serializers.EmailField: {
- 'base_template': 'input.html',
- 'input_type': 'email'
- },
- serializers.URLField: {
- 'base_template': 'input.html',
- 'input_type': 'url'
- },
- serializers.IntegerField: {
- 'base_template': 'input.html',
- 'input_type': 'number'
- },
- serializers.FloatField: {
- 'base_template': 'input.html',
- 'input_type': 'number'
- },
- serializers.DateTimeField: {
- 'base_template': 'input.html',
- 'input_type': 'datetime-local'
- },
- serializers.DateField: {
- 'base_template': 'input.html',
- 'input_type': 'date'
- },
- serializers.TimeField: {
- 'base_template': 'input.html',
- 'input_type': 'time'
- },
- serializers.FileField: {
- 'base_template': 'input.html',
- 'input_type': 'file'
- },
- serializers.BooleanField: {
- 'base_template': 'checkbox.html'
- },
- serializers.ChoiceField: {
- 'base_template': 'select.html', # Also valid: 'radio.html'
- },
- serializers.MultipleChoiceField: {
- 'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html'
- },
- serializers.RelatedField: {
- 'base_template': 'select.html', # Also valid: 'radio.html'
- },
- serializers.ManyRelatedField: {
- 'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html'
- },
- serializers.Serializer: {
- 'base_template': 'fieldset.html'
- },
- serializers.ListSerializer: {
- 'base_template': 'list_fieldset.html'
- },
- serializers.ListField: {
- 'base_template': 'list_field.html'
- },
- serializers.DictField: {
- 'base_template': 'dict_field.html'
- },
- serializers.FilePathField: {
- 'base_template': 'select.html',
- },
- serializers.JSONField: {
- 'base_template': 'textarea.html',
- },
- })
+ media_type = "text/html"
+ format = "form"
+ charset = "utf-8"
+ template_pack = "rest_framework/vertical/"
+ base_template = "form.html"
+
+ default_style = ClassLookupDict(
+ {
+ serializers.Field: {"base_template": "input.html", "input_type": "text"},
+ serializers.EmailField: {
+ "base_template": "input.html",
+ "input_type": "email",
+ },
+ serializers.URLField: {"base_template": "input.html", "input_type": "url"},
+ serializers.IntegerField: {
+ "base_template": "input.html",
+ "input_type": "number",
+ },
+ serializers.FloatField: {
+ "base_template": "input.html",
+ "input_type": "number",
+ },
+ serializers.DateTimeField: {
+ "base_template": "input.html",
+ "input_type": "datetime-local",
+ },
+ serializers.DateField: {
+ "base_template": "input.html",
+ "input_type": "date",
+ },
+ serializers.TimeField: {
+ "base_template": "input.html",
+ "input_type": "time",
+ },
+ serializers.FileField: {
+ "base_template": "input.html",
+ "input_type": "file",
+ },
+ serializers.BooleanField: {"base_template": "checkbox.html"},
+ serializers.ChoiceField: {
+ "base_template": "select.html" # Also valid: 'radio.html'
+ },
+ serializers.MultipleChoiceField: {
+ "base_template": "select_multiple.html" # Also valid: 'checkbox_multiple.html'
+ },
+ serializers.RelatedField: {
+ "base_template": "select.html" # Also valid: 'radio.html'
+ },
+ serializers.ManyRelatedField: {
+ "base_template": "select_multiple.html" # Also valid: 'checkbox_multiple.html'
+ },
+ serializers.Serializer: {"base_template": "fieldset.html"},
+ serializers.ListSerializer: {"base_template": "list_fieldset.html"},
+ serializers.ListField: {"base_template": "list_field.html"},
+ serializers.DictField: {"base_template": "dict_field.html"},
+ serializers.FilePathField: {"base_template": "select.html"},
+ serializers.JSONField: {"base_template": "textarea.html"},
+ }
+ )
def render_field(self, field, parent_style):
if isinstance(field._field, serializers.HiddenField):
- return ''
+ return ""
style = dict(self.default_style[field])
style.update(field.style)
- if 'template_pack' not in style:
- style['template_pack'] = parent_style.get('template_pack', self.template_pack)
- style['renderer'] = self
+ if "template_pack" not in style:
+ style["template_pack"] = parent_style.get(
+ "template_pack", self.template_pack
+ )
+ style["renderer"] = self
# Get a clone of the field with text-only value representation.
field = field.as_form_field()
- if style.get('input_type') == 'datetime-local' and isinstance(field.value, six.text_type):
- field.value = field.value.rstrip('Z')
+ if style.get("input_type") == "datetime-local" and isinstance(
+ field.value, six.text_type
+ ):
+ field.value = field.value.rstrip("Z")
- if 'template' in style:
- template_name = style['template']
+ if "template" in style:
+ template_name = style["template"]
else:
- template_name = style['template_pack'].strip('/') + '/' + style['base_template']
+ template_name = (
+ style["template_pack"].strip("/") + "/" + style["base_template"]
+ )
template = loader.get_template(template_name)
- context = {'field': field, 'style': style}
+ context = {"field": field, "style": style}
return template.render(context)
def render(self, data, accepted_media_type=None, renderer_context=None):
@@ -368,18 +368,15 @@ class HTMLFormRenderer(BaseRenderer):
renderer_context = renderer_context or {}
form = data.serializer
- style = renderer_context.get('style', {})
- if 'template_pack' not in style:
- style['template_pack'] = self.template_pack
- style['renderer'] = self
+ style = renderer_context.get("style", {})
+ if "template_pack" not in style:
+ style["template_pack"] = self.template_pack
+ style["renderer"] = self
- template_pack = style['template_pack'].strip('/')
- template_name = template_pack + '/' + self.base_template
+ template_pack = style["template_pack"].strip("/")
+ template_name = template_pack + "/" + self.base_template
template = loader.get_template(template_name)
- context = {
- 'form': form,
- 'style': style
- }
+ context = {"form": form, "style": style}
return template.render(context)
@@ -387,12 +384,13 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
HTML renderer used to self-document the API.
"""
- media_type = 'text/html'
- format = 'api'
- template = 'rest_framework/api.html'
- filter_template = 'rest_framework/filters/base.html'
- code_style = 'emacs'
- charset = 'utf-8'
+
+ media_type = "text/html"
+ format = "api"
+ template = "rest_framework/api.html"
+ filter_template = "rest_framework/filters/base.html"
+ code_style = "emacs"
+ charset = "utf-8"
form_renderer_class = HTMLFormRenderer
def get_default_renderer(self, view):
@@ -400,10 +398,16 @@ class BrowsableAPIRenderer(BaseRenderer):
Return an instance of the first valid renderer.
(Don't use another documenting renderer.)
"""
- renderers = [renderer for renderer in view.renderer_classes
- if not issubclass(renderer, BrowsableAPIRenderer)]
- non_template_renderers = [renderer for renderer in renderers
- if not hasattr(renderer, 'get_template_names')]
+ renderers = [
+ renderer
+ for renderer in view.renderer_classes
+ if not issubclass(renderer, BrowsableAPIRenderer)
+ ]
+ non_template_renderers = [
+ renderer
+ for renderer in renderers
+ if not hasattr(renderer, "get_template_names")
+ ]
if not renderers:
return None
@@ -411,23 +415,23 @@ class BrowsableAPIRenderer(BaseRenderer):
return non_template_renderers[0]()
return renderers[0]()
- def get_content(self, renderer, data,
- accepted_media_type, renderer_context):
+ def get_content(self, renderer, data, accepted_media_type, renderer_context):
"""
Get the content as if it had been rendered by the default
non-documenting renderer.
"""
if not renderer:
- return '[No renderers were found]'
+ return "[No renderers were found]"
- renderer_context['indent'] = 4
+ renderer_context["indent"] = 4
content = renderer.render(data, accepted_media_type, renderer_context)
- render_style = getattr(renderer, 'render_style', 'text')
- assert render_style in ['text', 'binary'], 'Expected .render_style ' \
- '"text" or "binary", but got "%s"' % render_style
- if render_style == 'binary':
- return '[%d bytes of binary content]' % len(content)
+ render_style = getattr(renderer, "render_style", "text")
+ assert render_style in ["text", "binary"], (
+ "Expected .render_style " '"text" or "binary", but got "%s"' % render_style
+ )
+ if render_style == "binary":
+ return "[%d bytes of binary content]" % len(content)
return content
@@ -446,11 +450,13 @@ class BrowsableAPIRenderer(BaseRenderer):
return False # Doesn't have permissions
return True
- def _get_serializer(self, serializer_class, view_instance, request, *args, **kwargs):
- kwargs['context'] = {
- 'request': request,
- 'format': self.format,
- 'view': view_instance
+ def _get_serializer(
+ self, serializer_class, view_instance, request, *args, **kwargs
+ ):
+ kwargs["context"] = {
+ "request": request,
+ "format": self.format,
+ "view": view_instance,
}
return serializer_class(*args, **kwargs)
@@ -462,9 +468,9 @@ class BrowsableAPIRenderer(BaseRenderer):
In the absence of the View having an associated form then return None.
"""
# See issue #2089 for refactoring this.
- serializer = getattr(data, 'serializer', None)
- if serializer and not getattr(serializer, 'many', False):
- instance = getattr(serializer, 'instance', None)
+ serializer = getattr(data, "serializer", None)
+ if serializer and not getattr(serializer, "many", False):
+ instance = getattr(serializer, "instance", None)
if isinstance(instance, Page):
instance = None
else:
@@ -475,7 +481,7 @@ class BrowsableAPIRenderer(BaseRenderer):
# serializer instance, rather than dynamically creating a new one.
if request.method == method and serializer is not None:
try:
- kwargs = {'data': request.data}
+ kwargs = {"data": request.data}
except ParseError:
kwargs = {}
existing_serializer = serializer
@@ -487,15 +493,14 @@ class BrowsableAPIRenderer(BaseRenderer):
if not self.show_form_for_method(view, method, request, instance):
return
- if method in ('DELETE', 'OPTIONS'):
+ if method in ("DELETE", "OPTIONS"):
return True # Don't actually need to return a form
- has_serializer = getattr(view, 'get_serializer', None)
- has_serializer_class = getattr(view, 'serializer_class', None)
+ has_serializer = getattr(view, "get_serializer", None)
+ has_serializer_class = getattr(view, "serializer_class", None)
- if (
- (not has_serializer and not has_serializer_class) or
- not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)
+ if (not has_serializer and not has_serializer_class) or not any(
+ is_form_media_type(parser.media_type) for parser in view.parser_classes
):
return
@@ -506,30 +511,36 @@ class BrowsableAPIRenderer(BaseRenderer):
pass
if has_serializer:
- if method in ('PUT', 'PATCH'):
+ if method in ("PUT", "PATCH"):
serializer = view.get_serializer(instance=instance, **kwargs)
else:
serializer = view.get_serializer(**kwargs)
else:
# at this point we must have a serializer_class
- if method in ('PUT', 'PATCH'):
- serializer = self._get_serializer(view.serializer_class, view,
- request, instance=instance, **kwargs)
+ if method in ("PUT", "PATCH"):
+ serializer = self._get_serializer(
+ view.serializer_class,
+ view,
+ request,
+ instance=instance,
+ **kwargs
+ )
else:
- serializer = self._get_serializer(view.serializer_class, view,
- request, **kwargs)
+ serializer = self._get_serializer(
+ view.serializer_class, view, request, **kwargs
+ )
return self.render_form_for_serializer(serializer)
def render_form_for_serializer(self, serializer):
- if hasattr(serializer, 'initial_data'):
+ if hasattr(serializer, "initial_data"):
serializer.is_valid()
form_renderer = self.form_renderer_class()
return form_renderer.render(
serializer.data,
self.accepted_media_type,
- {'style': {'template_pack': 'rest_framework/horizontal'}}
+ {"style": {"template_pack": "rest_framework/horizontal"}},
)
def get_raw_data_form(self, data, view, method, request):
@@ -539,9 +550,9 @@ class BrowsableAPIRenderer(BaseRenderer):
(Which are typically application/x-www-form-urlencoded)
"""
# See issue #2089 for refactoring this.
- serializer = getattr(data, 'serializer', None)
- if serializer and not getattr(serializer, 'many', False):
- instance = getattr(serializer, 'instance', None)
+ serializer = getattr(data, "serializer", None)
+ if serializer and not getattr(serializer, "many", False):
+ instance = getattr(serializer, "instance", None)
if isinstance(instance, Page):
instance = None
else:
@@ -554,12 +565,12 @@ class BrowsableAPIRenderer(BaseRenderer):
# If possible, serialize the initial content for the generic form
default_parser = view.parser_classes[0]
- renderer_class = getattr(default_parser, 'renderer_class', None)
- if hasattr(view, 'get_serializer') and renderer_class:
+ renderer_class = getattr(default_parser, "renderer_class", None)
+ if hasattr(view, "get_serializer") and renderer_class:
# View has a serializer defined and parser class has a
# corresponding renderer that can be used to render the data.
- if method in ('PUT', 'PATCH'):
+ if method in ("PUT", "PATCH"):
serializer = view.get_serializer(instance=instance)
else:
serializer = view.get_serializer()
@@ -568,7 +579,7 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer = renderer_class()
accepted = self.accepted_media_type
context = self.renderer_context.copy()
- context['indent'] = 4
+ context["indent"] = 4
# strip HiddenField from output
data = serializer.data.copy()
@@ -577,7 +588,7 @@ class BrowsableAPIRenderer(BaseRenderer):
data.pop(name, None)
content = renderer.render(data, accepted, context)
# Renders returns bytes, but CharField expects a str.
- content = content.decode('utf-8')
+ content = content.decode("utf-8")
else:
content = None
@@ -589,16 +600,16 @@ class BrowsableAPIRenderer(BaseRenderer):
class GenericContentForm(forms.Form):
_content_type = forms.ChoiceField(
- label='Media type',
+ label="Media type",
choices=choices,
initial=initial,
- widget=forms.Select(attrs={'data-override': 'content-type'})
+ widget=forms.Select(attrs={"data-override": "content-type"}),
)
_content = forms.CharField(
- label='Content',
- widget=forms.Textarea(attrs={'data-override': 'content'}),
+ label="Content",
+ widget=forms.Textarea(attrs={"data-override": "content"}),
initial=content,
- required=False
+ required=False,
)
return GenericContentForm()
@@ -608,23 +619,23 @@ class BrowsableAPIRenderer(BaseRenderer):
def get_description(self, view, status_code):
if status_code in (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN):
- return ''
+ return ""
return view.get_view_description(html=True)
def get_breadcrumbs(self, request):
return get_breadcrumbs(request.path, request)
def get_extra_actions(self, view):
- if hasattr(view, 'get_extra_action_url_map'):
+ if hasattr(view, "get_extra_action_url_map"):
return view.get_extra_action_url_map()
return None
def get_filter_form(self, data, view, request):
- if not hasattr(view, 'get_queryset') or not hasattr(view, 'filter_backends'):
+ if not hasattr(view, "get_queryset") or not hasattr(view, "filter_backends"):
return
# Infer if this is a list view or not.
- paginator = getattr(view, 'paginator', None)
+ paginator = getattr(view, "paginator", None)
if isinstance(data, list):
pass
elif paginator is not None and data is not None:
@@ -638,7 +649,7 @@ class BrowsableAPIRenderer(BaseRenderer):
queryset = view.get_queryset()
elements = []
for backend in view.filter_backends:
- if hasattr(backend, 'to_html'):
+ if hasattr(backend, "to_html"):
html = backend().to_html(request, queryset, view)
if html:
elements.append(html)
@@ -647,78 +658,76 @@ class BrowsableAPIRenderer(BaseRenderer):
return
template = loader.get_template(self.filter_template)
- context = {'elements': elements}
+ context = {"elements": elements}
return template.render(context)
def get_context(self, data, accepted_media_type, renderer_context):
"""
Returns the context used to render.
"""
- view = renderer_context['view']
- request = renderer_context['request']
- response = renderer_context['response']
+ view = renderer_context["view"]
+ request = renderer_context["request"]
+ response = renderer_context["response"]
renderer = self.get_default_renderer(view)
- raw_data_post_form = self.get_raw_data_form(data, view, 'POST', request)
- raw_data_put_form = self.get_raw_data_form(data, view, 'PUT', request)
- raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request)
+ raw_data_post_form = self.get_raw_data_form(data, view, "POST", request)
+ raw_data_put_form = self.get_raw_data_form(data, view, "PUT", request)
+ raw_data_patch_form = self.get_raw_data_form(data, view, "PATCH", request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
response_headers = OrderedDict(sorted(response.items()))
- renderer_content_type = ''
+ renderer_content_type = ""
if renderer:
- renderer_content_type = '%s' % renderer.media_type
+ renderer_content_type = "%s" % renderer.media_type
if renderer.charset:
- renderer_content_type += ' ;%s' % renderer.charset
- response_headers['Content-Type'] = renderer_content_type
+ renderer_content_type += " ;%s" % renderer.charset
+ response_headers["Content-Type"] = renderer_content_type
- if getattr(view, 'paginator', None) and view.paginator.display_page_controls:
+ if getattr(view, "paginator", None) and view.paginator.display_page_controls:
paginator = view.paginator
else:
paginator = None
csrf_cookie_name = settings.CSRF_COOKIE_NAME
csrf_header_name = settings.CSRF_HEADER_NAME
- if csrf_header_name.startswith('HTTP_'):
+ if csrf_header_name.startswith("HTTP_"):
csrf_header_name = csrf_header_name[5:]
- csrf_header_name = csrf_header_name.replace('_', '-')
+ csrf_header_name = csrf_header_name.replace("_", "-")
context = {
- 'content': self.get_content(renderer, data, accepted_media_type, renderer_context),
- 'code_style': pygments_css(self.code_style),
- 'view': view,
- 'request': request,
- 'response': response,
- 'user': request.user,
- 'description': self.get_description(view, response.status_code),
- 'name': self.get_name(view),
- 'version': VERSION,
- 'paginator': paginator,
- 'breadcrumblist': self.get_breadcrumbs(request),
- 'allowed_methods': view.allowed_methods,
- 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
- 'response_headers': response_headers,
-
- 'put_form': self.get_rendered_html_form(data, view, 'PUT', request),
- 'post_form': self.get_rendered_html_form(data, view, 'POST', request),
- 'delete_form': self.get_rendered_html_form(data, view, 'DELETE', request),
- 'options_form': self.get_rendered_html_form(data, view, 'OPTIONS', request),
-
- 'extra_actions': self.get_extra_actions(view),
-
- 'filter_form': self.get_filter_form(data, view, request),
-
- 'raw_data_put_form': raw_data_put_form,
- 'raw_data_post_form': raw_data_post_form,
- 'raw_data_patch_form': raw_data_patch_form,
- 'raw_data_put_or_patch_form': raw_data_put_or_patch_form,
-
- 'display_edit_forms': bool(response.status_code != 403),
-
- 'api_settings': api_settings,
- 'csrf_cookie_name': csrf_cookie_name,
- 'csrf_header_name': csrf_header_name
+ "content": self.get_content(
+ renderer, data, accepted_media_type, renderer_context
+ ),
+ "code_style": pygments_css(self.code_style),
+ "view": view,
+ "request": request,
+ "response": response,
+ "user": request.user,
+ "description": self.get_description(view, response.status_code),
+ "name": self.get_name(view),
+ "version": VERSION,
+ "paginator": paginator,
+ "breadcrumblist": self.get_breadcrumbs(request),
+ "allowed_methods": view.allowed_methods,
+ "available_formats": [
+ renderer_cls.format for renderer_cls in view.renderer_classes
+ ],
+ "response_headers": response_headers,
+ "put_form": self.get_rendered_html_form(data, view, "PUT", request),
+ "post_form": self.get_rendered_html_form(data, view, "POST", request),
+ "delete_form": self.get_rendered_html_form(data, view, "DELETE", request),
+ "options_form": self.get_rendered_html_form(data, view, "OPTIONS", request),
+ "extra_actions": self.get_extra_actions(view),
+ "filter_form": self.get_filter_form(data, view, request),
+ "raw_data_put_form": raw_data_put_form,
+ "raw_data_post_form": raw_data_post_form,
+ "raw_data_patch_form": raw_data_patch_form,
+ "raw_data_put_or_patch_form": raw_data_put_or_patch_form,
+ "display_edit_forms": bool(response.status_code != 403),
+ "api_settings": api_settings,
+ "csrf_cookie_name": csrf_cookie_name,
+ "csrf_header_name": csrf_header_name,
}
return context
@@ -726,17 +735,17 @@ class BrowsableAPIRenderer(BaseRenderer):
"""
Render the HTML for the browsable API representation.
"""
- self.accepted_media_type = accepted_media_type or ''
+ self.accepted_media_type = accepted_media_type or ""
self.renderer_context = renderer_context or {}
template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context)
- ret = template.render(context, request=renderer_context['request'])
+ ret = template.render(context, request=renderer_context["request"])
# Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include
# the normal deletion response code in the output)
- response = renderer_context['response']
+ response = renderer_context["response"]
if response.status_code == status.HTTP_204_NO_CONTENT:
response.status_code = status.HTTP_200_OK
@@ -744,46 +753,50 @@ class BrowsableAPIRenderer(BaseRenderer):
class AdminRenderer(BrowsableAPIRenderer):
- template = 'rest_framework/admin.html'
- format = 'admin'
+ template = "rest_framework/admin.html"
+ format = "admin"
def render(self, data, accepted_media_type=None, renderer_context=None):
- self.accepted_media_type = accepted_media_type or ''
+ self.accepted_media_type = accepted_media_type or ""
self.renderer_context = renderer_context or {}
- response = renderer_context['response']
- request = renderer_context['request']
- view = self.renderer_context['view']
+ response = renderer_context["response"]
+ request = renderer_context["request"]
+ view = self.renderer_context["view"]
if response.status_code == status.HTTP_400_BAD_REQUEST:
# Errors still need to display the list or detail information.
# The only way we can get at that is to simulate a GET request.
- self.error_form = self.get_rendered_html_form(data, view, request.method, request)
- self.error_title = {'POST': 'Create', 'PUT': 'Edit'}.get(request.method, 'Errors')
+ self.error_form = self.get_rendered_html_form(
+ data, view, request.method, request
+ )
+ self.error_title = {"POST": "Create", "PUT": "Edit"}.get(
+ request.method, "Errors"
+ )
- with override_method(view, request, 'GET') as request:
+ with override_method(view, request, "GET") as request:
response = view.get(request, *view.args, **view.kwargs)
data = response.data
template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context)
- ret = template.render(context, request=renderer_context['request'])
+ ret = template.render(context, request=renderer_context["request"])
# Creation and deletion should use redirects in the admin style.
- if response.status_code == status.HTTP_201_CREATED and 'Location' in response:
+ if response.status_code == status.HTTP_201_CREATED and "Location" in response:
response.status_code = status.HTTP_303_SEE_OTHER
- response['Location'] = request.build_absolute_uri()
- ret = ''
+ response["Location"] = request.build_absolute_uri()
+ ret = ""
if response.status_code == status.HTTP_204_NO_CONTENT:
response.status_code = status.HTTP_303_SEE_OTHER
try:
# Attempt to get the parent breadcrumb URL.
- response['Location'] = self.get_breadcrumbs(request)[-2][1]
+ response["Location"] = self.get_breadcrumbs(request)[-2][1]
except KeyError:
# Otherwise reload current URL to get a 'Not Found' page.
- response['Location'] = request.full_path
- ret = ''
+ response["Location"] = request.full_path
+ ret = ""
return ret
@@ -795,7 +808,7 @@ class AdminRenderer(BrowsableAPIRenderer):
data, accepted_media_type, renderer_context
)
- paginator = getattr(context['view'], 'paginator', None)
+ paginator = getattr(context["view"], "paginator", None)
if paginator is not None and data is not None:
try:
results = paginator.get_results(data)
@@ -806,29 +819,29 @@ class AdminRenderer(BrowsableAPIRenderer):
if results is None:
header = {}
- style = 'detail'
+ style = "detail"
elif isinstance(results, list):
header = results[0] if results else {}
- style = 'list'
+ style = "list"
else:
header = results
- style = 'detail'
+ style = "detail"
- columns = [key for key in header if key != 'url']
- details = [key for key in header if key != 'url']
+ columns = [key for key in header if key != "url"]
+ details = [key for key in header if key != "url"]
- if isinstance(results, list) and 'view' in renderer_context:
+ if isinstance(results, list) and "view" in renderer_context:
for result in results:
- url = self.get_result_url(result, context['view'])
+ url = self.get_result_url(result, context["view"])
if url is not None:
- result.setdefault('url', url)
+ result.setdefault("url", url)
- context['style'] = style
- context['columns'] = columns
- context['details'] = details
- context['results'] = results
- context['error_form'] = getattr(self, 'error_form', None)
- context['error_title'] = getattr(self, 'error_title', None)
+ context["style"] = style
+ context["columns"] = columns
+ context["details"] = details
+ context["results"] = results
+ context["error_form"] = getattr(self, "error_form", None)
+ context["error_title"] = getattr(self, "error_title", None)
return context
def get_result_url(self, result, view):
@@ -838,79 +851,82 @@ class AdminRenderer(BrowsableAPIRenderer):
This only works with views that are generic-like (has `.lookup_field`)
and viewset-like (has `.basename` / `.reverse_action()`).
"""
- if not hasattr(view, 'reverse_action') or \
- not hasattr(view, 'lookup_field'):
+ if not hasattr(view, "reverse_action") or not hasattr(view, "lookup_field"):
return
lookup_field = view.lookup_field
- lookup_url_kwarg = getattr(view, 'lookup_url_kwarg', None) or lookup_field
+ lookup_url_kwarg = getattr(view, "lookup_url_kwarg", None) or lookup_field
try:
kwargs = {lookup_url_kwarg: result[lookup_field]}
- return view.reverse_action('detail', kwargs=kwargs)
+ return view.reverse_action("detail", kwargs=kwargs)
except (KeyError, NoReverseMatch):
return
class DocumentationRenderer(BaseRenderer):
- media_type = 'text/html'
- format = 'html'
- charset = 'utf-8'
- template = 'rest_framework/docs/index.html'
- error_template = 'rest_framework/docs/error.html'
- code_style = 'emacs'
- languages = ['shell', 'javascript', 'python']
+ media_type = "text/html"
+ format = "html"
+ charset = "utf-8"
+ template = "rest_framework/docs/index.html"
+ error_template = "rest_framework/docs/error.html"
+ code_style = "emacs"
+ languages = ["shell", "javascript", "python"]
def get_context(self, data, request):
return {
- 'document': data,
- 'langs': self.languages,
- 'lang_htmls': ["rest_framework/docs/langs/%s.html" % l for l in self.languages],
- 'lang_intro_htmls': ["rest_framework/docs/langs/%s-intro.html" % l for l in self.languages],
- 'code_style': pygments_css(self.code_style),
- 'request': request
+ "document": data,
+ "langs": self.languages,
+ "lang_htmls": [
+ "rest_framework/docs/langs/%s.html" % l for l in self.languages
+ ],
+ "lang_intro_htmls": [
+ "rest_framework/docs/langs/%s-intro.html" % l for l in self.languages
+ ],
+ "code_style": pygments_css(self.code_style),
+ "request": request,
}
def render(self, data, accepted_media_type=None, renderer_context=None):
if isinstance(data, coreapi.Document):
template = loader.get_template(self.template)
- context = self.get_context(data, renderer_context['request'])
- return template.render(context, request=renderer_context['request'])
+ context = self.get_context(data, renderer_context["request"])
+ return template.render(context, request=renderer_context["request"])
else:
template = loader.get_template(self.error_template)
context = {
"data": data,
- "request": renderer_context['request'],
- "response": renderer_context['response'],
+ "request": renderer_context["request"],
+ "response": renderer_context["response"],
"debug": settings.DEBUG,
}
- return template.render(context, request=renderer_context['request'])
+ return template.render(context, request=renderer_context["request"])
class SchemaJSRenderer(BaseRenderer):
- media_type = 'application/javascript'
- format = 'javascript'
- charset = 'utf-8'
- template = 'rest_framework/schema.js'
+ media_type = "application/javascript"
+ format = "javascript"
+ charset = "utf-8"
+ template = "rest_framework/schema.js"
def render(self, data, accepted_media_type=None, renderer_context=None):
codec = coreapi.codecs.CoreJSONCodec()
- schema = base64.b64encode(codec.encode(data)).decode('ascii')
+ schema = base64.b64encode(codec.encode(data)).decode("ascii")
template = loader.get_template(self.template)
- context = {'schema': mark_safe(schema)}
- request = renderer_context['request']
+ context = {"schema": mark_safe(schema)}
+ request = renderer_context["request"]
return template.render(context, request=request)
class MultiPartRenderer(BaseRenderer):
- media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
- format = 'multipart'
- charset = 'utf-8'
- BOUNDARY = 'BoUnDaRyStRiNg'
+ media_type = "multipart/form-data; boundary=BoUnDaRyStRiNg"
+ format = "multipart"
+ charset = "utf-8"
+ BOUNDARY = "BoUnDaRyStRiNg"
def render(self, data, accepted_media_type=None, renderer_context=None):
- if hasattr(data, 'items'):
+ if hasattr(data, "items"):
for key, value in data.items():
assert not isinstance(value, dict), (
"Test data contained a dictionary value for key '%s', "
@@ -922,15 +938,15 @@ class MultiPartRenderer(BaseRenderer):
class CoreJSONRenderer(BaseRenderer):
- media_type = 'application/coreapi+json'
+ media_type = "application/coreapi+json"
charset = None
- format = 'corejson'
+ format = "corejson"
def __init__(self):
- assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not installed.'
+ assert coreapi, "Using CoreJSONRenderer, but `coreapi` is not installed."
def render(self, data, media_type=None, renderer_context=None):
- indent = bool(renderer_context.get('indent', 0))
+ indent = bool(renderer_context.get("indent", 0))
codec = coreapi.codecs.CoreJSONCodec()
return codec.dump(data, indent=indent)
@@ -938,38 +954,35 @@ class CoreJSONRenderer(BaseRenderer):
class _BaseOpenAPIRenderer:
def get_schema(self, instance):
CLASS_TO_TYPENAME = {
- coreschema.Object: 'object',
- coreschema.Array: 'array',
- coreschema.Number: 'number',
- coreschema.Integer: 'integer',
- coreschema.String: 'string',
- coreschema.Boolean: 'boolean',
+ coreschema.Object: "object",
+ coreschema.Array: "array",
+ coreschema.Number: "number",
+ coreschema.Integer: "integer",
+ coreschema.String: "string",
+ coreschema.Boolean: "boolean",
}
schema = {}
if instance.__class__ in CLASS_TO_TYPENAME:
- schema['type'] = CLASS_TO_TYPENAME[instance.__class__]
- schema['title'] = instance.title
- schema['description'] = instance.description
- if hasattr(instance, 'enum'):
- schema['enum'] = instance.enum
+ schema["type"] = CLASS_TO_TYPENAME[instance.__class__]
+ schema["title"] = instance.title
+ schema["description"] = instance.description
+ if hasattr(instance, "enum"):
+ schema["enum"] = instance.enum
return schema
def get_parameters(self, link):
parameters = []
for field in link.fields:
- if field.location not in ['path', 'query']:
+ if field.location not in ["path", "query"]:
continue
- parameter = {
- 'name': field.name,
- 'in': field.location,
- }
+ parameter = {"name": field.name, "in": field.location}
if field.required:
- parameter['required'] = True
+ parameter["required"] = True
if field.description:
- parameter['description'] = field.description
+ parameter["description"] = field.description
if field.schema:
- parameter['schema'] = self.get_schema(field.schema)
+ parameter["schema"] = self.get_schema(field.schema)
parameters.append(parameter)
return parameters
@@ -977,17 +990,15 @@ class _BaseOpenAPIRenderer:
operation_id = "%s_%s" % (tag, name) if tag else name
parameters = self.get_parameters(link)
- operation = {
- 'operationId': operation_id,
- }
+ operation = {"operationId": operation_id}
if link.title:
- operation['summary'] = link.title
+ operation["summary"] = link.title
if link.description:
- operation['description'] = link.description
+ operation["description"] = link.description
if parameters:
- operation['parameters'] = parameters
+ operation["parameters"] = parameters
if tag:
- operation['tags'] = [tag]
+ operation["tags"] = [tag]
return operation
def get_paths(self, document):
@@ -1011,41 +1022,39 @@ class _BaseOpenAPIRenderer:
def get_structure(self, data):
return {
- 'openapi': '3.0.0',
- 'info': {
- 'version': '',
- 'title': data.title,
- 'description': data.description
+ "openapi": "3.0.0",
+ "info": {
+ "version": "",
+ "title": data.title,
+ "description": data.description,
},
- 'servers': [{
- 'url': data.url
- }],
- 'paths': self.get_paths(data)
+ "servers": [{"url": data.url}],
+ "paths": self.get_paths(data),
}
class OpenAPIRenderer(_BaseOpenAPIRenderer):
- media_type = 'application/vnd.oai.openapi'
+ media_type = "application/vnd.oai.openapi"
charset = None
- format = 'openapi'
+ format = "openapi"
def __init__(self):
- assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.'
- assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.'
+ assert coreapi, "Using OpenAPIRenderer, but `coreapi` is not installed."
+ assert yaml, "Using OpenAPIRenderer, but `pyyaml` is not installed."
def render(self, data, media_type=None, renderer_context=None):
structure = self.get_structure(data)
- return yaml.dump(structure, default_flow_style=False).encode('utf-8')
+ return yaml.dump(structure, default_flow_style=False).encode("utf-8")
class JSONOpenAPIRenderer(_BaseOpenAPIRenderer):
- media_type = 'application/vnd.oai.openapi+json'
+ media_type = "application/vnd.oai.openapi+json"
charset = None
- format = 'openapi-json'
+ format = "openapi-json"
def __init__(self):
- assert coreapi, 'Using JSONOpenAPIRenderer, but `coreapi` is not installed.'
+ assert coreapi, "Using JSONOpenAPIRenderer, but `coreapi` is not installed."
def render(self, data, media_type=None, renderer_context=None):
structure = self.get_structure(data)
- return json.dumps(structure, indent=4).encode('utf-8')
+ return json.dumps(structure, indent=4).encode("utf-8")
diff --git a/rest_framework/request.py b/rest_framework/request.py
index a6d92e2bd..774d177c1 100644
--- a/rest_framework/request.py
+++ b/rest_framework/request.py
@@ -30,8 +30,10 @@ def is_form_media_type(media_type):
Return True if the media type is a valid form media type.
"""
base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING))
- return (base_media_type == 'application/x-www-form-urlencoded' or
- base_media_type == 'multipart/form-data')
+ return (
+ base_media_type == "application/x-www-form-urlencoded"
+ or base_media_type == "multipart/form-data"
+ )
class override_method(object):
@@ -49,12 +51,12 @@ class override_method(object):
self.view = view
self.request = request
self.method = method
- self.action = getattr(view, 'action', None)
+ self.action = getattr(view, "action", None)
def __enter__(self):
self.view.request = clone_request(self.request, self.method)
# For viewsets we also set the `.action` attribute.
- action_map = getattr(self.view, 'action_map', {})
+ action_map = getattr(self.view, "action_map", {})
self.view.action = action_map.get(self.method.lower())
return self.view.request
@@ -86,6 +88,7 @@ class Empty(object):
Placeholder for unset attributes.
Cannot use `None`, as that may be a valid value.
"""
+
pass
@@ -98,30 +101,32 @@ def clone_request(request, method):
Internal helper method to clone a request, replacing with a different
HTTP method. Used for checking permissions against other methods.
"""
- ret = Request(request=request._request,
- parsers=request.parsers,
- authenticators=request.authenticators,
- negotiator=request.negotiator,
- parser_context=request.parser_context)
+ ret = Request(
+ request=request._request,
+ parsers=request.parsers,
+ authenticators=request.authenticators,
+ negotiator=request.negotiator,
+ parser_context=request.parser_context,
+ )
ret._data = request._data
ret._files = request._files
ret._full_data = request._full_data
ret._content_type = request._content_type
ret._stream = request._stream
ret.method = method
- if hasattr(request, '_user'):
+ if hasattr(request, "_user"):
ret._user = request._user
- if hasattr(request, '_auth'):
+ if hasattr(request, "_auth"):
ret._auth = request._auth
- if hasattr(request, '_authenticator'):
+ if hasattr(request, "_authenticator"):
ret._authenticator = request._authenticator
- if hasattr(request, 'accepted_renderer'):
+ if hasattr(request, "accepted_renderer"):
ret.accepted_renderer = request.accepted_renderer
- if hasattr(request, 'accepted_media_type'):
+ if hasattr(request, "accepted_media_type"):
ret.accepted_media_type = request.accepted_media_type
- if hasattr(request, 'version'):
+ if hasattr(request, "version"):
ret.version = request.version
- if hasattr(request, 'versioning_scheme'):
+ if hasattr(request, "versioning_scheme"):
ret.versioning_scheme = request.versioning_scheme
return ret
@@ -152,12 +157,19 @@ class Request(object):
authenticating the request's user.
"""
- def __init__(self, request, parsers=None, authenticators=None,
- negotiator=None, parser_context=None):
+ def __init__(
+ self,
+ request,
+ parsers=None,
+ authenticators=None,
+ negotiator=None,
+ parser_context=None,
+ ):
assert isinstance(request, HttpRequest), (
- 'The `request` argument must be an instance of '
- '`django.http.HttpRequest`, not `{}.{}`.'
- .format(request.__class__.__module__, request.__class__.__name__)
+ "The `request` argument must be an instance of "
+ "`django.http.HttpRequest`, not `{}.{}`.".format(
+ request.__class__.__module__, request.__class__.__name__
+ )
)
self._request = request
@@ -173,11 +185,11 @@ class Request(object):
if self.parser_context is None:
self.parser_context = {}
- self.parser_context['request'] = self
- self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
+ self.parser_context["request"] = self
+ self.parser_context["encoding"] = request.encoding or settings.DEFAULT_CHARSET
- force_user = getattr(request, '_force_auth_user', None)
- force_token = getattr(request, '_force_auth_token', None)
+ force_user = getattr(request, "_force_auth_user", None)
+ force_token = getattr(request, "_force_auth_token", None)
if force_user is not None or force_token is not None:
forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,)
@@ -188,14 +200,14 @@ class Request(object):
@property
def content_type(self):
meta = self._request.META
- return meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
+ return meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
@property
def stream(self):
"""
Returns an object that may be used to stream the request content.
"""
- if not _hasattr(self, '_stream'):
+ if not _hasattr(self, "_stream"):
self._load_stream()
return self._stream
@@ -208,7 +220,7 @@ class Request(object):
@property
def data(self):
- if not _hasattr(self, '_full_data'):
+ if not _hasattr(self, "_full_data"):
self._load_data_and_files()
return self._full_data
@@ -218,7 +230,7 @@ class Request(object):
Returns the user associated with the current request, as authenticated
by the authentication classes provided to the request.
"""
- if not hasattr(self, '_user'):
+ if not hasattr(self, "_user"):
with wrap_attributeerrors():
self._authenticate()
return self._user
@@ -242,7 +254,7 @@ class Request(object):
Returns any non-user authentication information associated with the
request, such as an authentication token.
"""
- if not hasattr(self, '_auth'):
+ if not hasattr(self, "_auth"):
with wrap_attributeerrors():
self._authenticate()
return self._auth
@@ -262,7 +274,7 @@ class Request(object):
Return the instance of the authentication instance class that was used
to authenticate the request, or `None`.
"""
- if not hasattr(self, '_authenticator'):
+ if not hasattr(self, "_authenticator"):
with wrap_attributeerrors():
self._authenticate()
return self._authenticator
@@ -271,7 +283,7 @@ class Request(object):
"""
Parses the request content into `self.data`.
"""
- if not _hasattr(self, '_data'):
+ if not _hasattr(self, "_data"):
self._data, self._files = self._parse()
if self._files:
self._full_data = self._data.copy()
@@ -292,7 +304,7 @@ class Request(object):
meta = self._request.META
try:
content_length = int(
- meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))
+ meta.get("CONTENT_LENGTH", meta.get("HTTP_CONTENT_LENGTH", 0))
)
except (ValueError, TypeError):
content_length = 0
@@ -308,10 +320,7 @@ class Request(object):
"""
Return True if this requests supports parsing form data.
"""
- form_media = (
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- )
+ form_media = ("application/x-www-form-urlencoded", "multipart/form-data")
return any([parser.media_type in form_media for parser in self.parsers])
def _parse(self):
@@ -324,7 +333,7 @@ class Request(object):
try:
stream = self.stream
except RawPostDataException:
- if not hasattr(self._request, '_post'):
+ if not hasattr(self._request, "_post"):
raise
# If request.POST has been accessed in middleware, and a method='POST'
# request was made with 'multipart/form-data', then the request stream
@@ -335,7 +344,7 @@ class Request(object):
if stream is None or media_type is None:
if media_type and is_form_media_type(media_type):
- empty_data = QueryDict('', encoding=self._request._encoding)
+ empty_data = QueryDict("", encoding=self._request._encoding)
else:
empty_data = {}
empty_files = MultiValueDict()
@@ -353,7 +362,7 @@ class Request(object):
# re-raise. Ensures we don't simply repeat the error when
# attempting to render the browsable renderer response, or when
# logging the request or similar.
- self._data = QueryDict('', encoding=self._request._encoding)
+ self._data = QueryDict("", encoding=self._request._encoding)
self._files = MultiValueDict()
self._full_data = self._data
raise
@@ -416,33 +425,33 @@ class Request(object):
@property
def DATA(self):
raise NotImplementedError(
- '`request.DATA` has been deprecated in favor of `request.data` '
- 'since version 3.0, and has been fully removed as of version 3.2.'
+ "`request.DATA` has been deprecated in favor of `request.data` "
+ "since version 3.0, and has been fully removed as of version 3.2."
)
@property
def POST(self):
# Ensure that request.POST uses our request parsing.
- if not _hasattr(self, '_data'):
+ if not _hasattr(self, "_data"):
self._load_data_and_files()
if is_form_media_type(self.content_type):
return self._data
- return QueryDict('', encoding=self._request._encoding)
+ return QueryDict("", encoding=self._request._encoding)
@property
def FILES(self):
# Leave this one alone for backwards compat with Django's request.FILES
# Different from the other two cases, which are not valid property
# names on the WSGIRequest class.
- if not _hasattr(self, '_files'):
+ if not _hasattr(self, "_files"):
self._load_data_and_files()
return self._files
@property
def QUERY_PARAMS(self):
raise NotImplementedError(
- '`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` '
- 'since version 3.0, and has been fully removed as of version 3.2.'
+ "`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` "
+ "since version 3.0, and has been fully removed as of version 3.2."
)
def force_plaintext_errors(self, value):
diff --git a/rest_framework/response.py b/rest_framework/response.py
index bf0663255..3d476e4e1 100644
--- a/rest_framework/response.py
+++ b/rest_framework/response.py
@@ -19,9 +19,15 @@ class Response(SimpleTemplateResponse):
arbitrary media types.
"""
- def __init__(self, data=None, status=None,
- template_name=None, headers=None,
- exception=False, content_type=None):
+ def __init__(
+ self,
+ data=None,
+ status=None,
+ template_name=None,
+ headers=None,
+ exception=False,
+ content_type=None,
+ ):
"""
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
@@ -33,9 +39,9 @@ class Response(SimpleTemplateResponse):
if isinstance(data, Serializer):
msg = (
- 'You passed a Serializer instance as data, but '
- 'probably meant to pass serialized `.data` or '
- '`.error`. representation.'
+ "You passed a Serializer instance as data, but "
+ "probably meant to pass serialized `.data` or "
+ "`.error`. representation."
)
raise AssertionError(msg)
@@ -50,14 +56,14 @@ class Response(SimpleTemplateResponse):
@property
def rendered_content(self):
- renderer = getattr(self, 'accepted_renderer', None)
- accepted_media_type = getattr(self, 'accepted_media_type', None)
- context = getattr(self, 'renderer_context', None)
+ renderer = getattr(self, "accepted_renderer", None)
+ accepted_media_type = getattr(self, "accepted_media_type", None)
+ context = getattr(self, "renderer_context", None)
assert renderer, ".accepted_renderer not set on Response"
assert accepted_media_type, ".accepted_media_type not set on Response"
assert context is not None, ".renderer_context not set on Response"
- context['response'] = self
+ context["response"] = self
media_type = renderer.media_type
charset = renderer.charset
@@ -67,18 +73,17 @@ class Response(SimpleTemplateResponse):
content_type = "{0}; charset={1}".format(media_type, charset)
elif content_type is None:
content_type = media_type
- self['Content-Type'] = content_type
+ self["Content-Type"] = content_type
ret = renderer.render(self.data, accepted_media_type, context)
if isinstance(ret, six.text_type):
assert charset, (
- 'renderer returned unicode, and did not specify '
- 'a charset value.'
+ "renderer returned unicode, and did not specify " "a charset value."
)
return bytes(ret.encode(charset))
if not ret:
- del self['Content-Type']
+ del self["Content-Type"]
return ret
@@ -88,7 +93,7 @@ class Response(SimpleTemplateResponse):
Returns reason text corresponding to our HTTP response status code.
Provided for convenience.
"""
- return responses.get(self.status_code, '')
+ return responses.get(self.status_code, "")
def __getstate__(self):
"""
@@ -96,10 +101,15 @@ class Response(SimpleTemplateResponse):
"""
state = super(Response, self).__getstate__()
for key in (
- 'accepted_renderer', 'renderer_context', 'resolver_match',
- 'client', 'request', 'json', 'wsgi_request'
+ "accepted_renderer",
+ "renderer_context",
+ "resolver_match",
+ "client",
+ "request",
+ "json",
+ "wsgi_request",
):
if key in state:
del state[key]
- state['_closable_objects'] = []
+ state["_closable_objects"] = []
return state
diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py
index e9cf737f1..38d846b16 100644
--- a/rest_framework/reverse.py
+++ b/rest_framework/reverse.py
@@ -3,8 +3,7 @@ Provide urlresolver functions that return fully qualified URLs or view names
"""
from __future__ import unicode_literals
-from django.urls import NoReverseMatch
-from django.urls import reverse as django_reverse
+from django.urls import NoReverseMatch, reverse as django_reverse
from django.utils import six
from django.utils.functional import lazy
@@ -20,9 +19,7 @@ def preserve_builtin_query_params(url, request=None):
if request is None:
return url
- overrides = [
- api_settings.URL_FORMAT_OVERRIDE,
- ]
+ overrides = [api_settings.URL_FORMAT_OVERRIDE]
for param in overrides:
if param and (param in request.GET):
@@ -38,7 +35,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra
to the versioning scheme instance, so that the resulting URL
can be modified if needed.
"""
- scheme = getattr(request, 'versioning_scheme', None)
+ scheme = getattr(request, "versioning_scheme", None)
if scheme is not None:
try:
url = scheme.reverse(viewname, args, kwargs, request, format, **extra)
@@ -59,7 +56,7 @@ def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extr
"""
if format is not None:
kwargs = kwargs or {}
- kwargs['format'] = format
+ kwargs["format"] = format
url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request:
return request.build_absolute_uri(url)
diff --git a/rest_framework/routers.py b/rest_framework/routers.py
index 1cacea181..6296b0f9c 100644
--- a/rest_framework/routers.py
+++ b/rest_framework/routers.py
@@ -25,9 +25,7 @@ from django.urls import NoReverseMatch
from django.utils import six
from django.utils.deprecation import RenameMethodsBase
-from rest_framework import (
- RemovedInDRF310Warning, RemovedInDRF311Warning, views
-)
+from rest_framework import RemovedInDRF310Warning, RemovedInDRF311Warning, views
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
@@ -35,8 +33,9 @@ from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns
-Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
-DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])
+
+Route = namedtuple("Route", ["url", "mapping", "name", "detail", "initkwargs"])
+DynamicRoute = namedtuple("DynamicRoute", ["url", "name", "detail", "initkwargs"])
class DynamicDetailRoute(object):
@@ -45,7 +44,8 @@ class DynamicDetailRoute(object):
"`DynamicDetailRoute` is deprecated and will be removed in 3.10 "
"in favor of `DynamicRoute`, which accepts a `detail` boolean. Use "
"`DynamicRoute(url, name, True, initkwargs)` instead.",
- RemovedInDRF310Warning, stacklevel=2
+ RemovedInDRF310Warning,
+ stacklevel=2,
)
return DynamicRoute(url, name, True, initkwargs)
@@ -56,7 +56,8 @@ class DynamicListRoute(object):
"`DynamicListRoute` is deprecated and will be removed in 3.10 in "
"favor of `DynamicRoute`, which accepts a `detail` boolean. Use "
"`DynamicRoute(url, name, False, initkwargs)` instead.",
- RemovedInDRF310Warning, stacklevel=2
+ RemovedInDRF310Warning,
+ stacklevel=2,
)
return DynamicRoute(url, name, False, initkwargs)
@@ -65,8 +66,8 @@ def escape_curly_brackets(url_path):
"""
Double brackets in regex of url_path for escape string formatting
"""
- if ('{' and '}') in url_path:
- url_path = url_path.replace('{', '{{').replace('}', '}}')
+ if ("{" and "}") in url_path:
+ url_path = url_path.replace("{", "{{").replace("}", "}}")
return url_path
@@ -79,7 +80,7 @@ def flatten(list_of_lists):
class RenameRouterMethods(RenameMethodsBase):
renamed_methods = (
- ('get_default_base_name', 'get_default_basename', RemovedInDRF311Warning),
+ ("get_default_base_name", "get_default_basename", RemovedInDRF311Warning),
)
@@ -92,8 +93,9 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
warnings.warn(msg, RemovedInDRF311Warning, 2)
- assert not (basename and base_name), (
- "Do not provide both the `basename` and `base_name` arguments.")
+ assert not (
+ basename and base_name
+ ), "Do not provide both the `basename` and `base_name` arguments."
if basename is None:
basename = base_name
@@ -103,7 +105,7 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
self.registry.append((prefix, viewset, basename))
# invalidate the urls cache
- if hasattr(self, '_urls'):
+ if hasattr(self, "_urls"):
del self._urls
def get_default_basename(self, viewset):
@@ -111,17 +113,17 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
- raise NotImplementedError('get_default_basename must be overridden')
+ raise NotImplementedError("get_default_basename must be overridden")
def get_urls(self):
"""
Return a list of URL patterns, given the registered viewsets.
"""
- raise NotImplementedError('get_urls must be overridden')
+ raise NotImplementedError("get_urls must be overridden")
@property
def urls(self):
- if not hasattr(self, '_urls'):
+ if not hasattr(self, "_urls"):
self._urls = self.get_urls()
return self._urls
@@ -131,48 +133,45 @@ class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
- url=r'^{prefix}{trailing_slash}$',
- mapping={
- 'get': 'list',
- 'post': 'create'
- },
- name='{basename}-list',
+ url=r"^{prefix}{trailing_slash}$",
+ mapping={"get": "list", "post": "create"},
+ name="{basename}-list",
detail=False,
- initkwargs={'suffix': 'List'}
+ initkwargs={"suffix": "List"},
),
# Dynamically generated list routes. Generated using
# @action(detail=False) decorator on methods of the viewset.
DynamicRoute(
- url=r'^{prefix}/{url_path}{trailing_slash}$',
- name='{basename}-{url_name}',
+ url=r"^{prefix}/{url_path}{trailing_slash}$",
+ name="{basename}-{url_name}",
detail=False,
- initkwargs={}
+ initkwargs={},
),
# Detail route.
Route(
- url=r'^{prefix}/{lookup}{trailing_slash}$',
+ url=r"^{prefix}/{lookup}{trailing_slash}$",
mapping={
- 'get': 'retrieve',
- 'put': 'update',
- 'patch': 'partial_update',
- 'delete': 'destroy'
+ "get": "retrieve",
+ "put": "update",
+ "patch": "partial_update",
+ "delete": "destroy",
},
- name='{basename}-detail',
+ name="{basename}-detail",
detail=True,
- initkwargs={'suffix': 'Instance'}
+ initkwargs={"suffix": "Instance"},
),
# Dynamically generated detail routes. Generated using
# @action(detail=True) decorator on methods of the viewset.
DynamicRoute(
- url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
- name='{basename}-{url_name}',
+ url=r"^{prefix}/{lookup}/{url_path}{trailing_slash}$",
+ name="{basename}-{url_name}",
detail=True,
- initkwargs={}
+ initkwargs={},
),
]
def __init__(self, trailing_slash=True):
- self.trailing_slash = '/' if trailing_slash else ''
+ self.trailing_slash = "/" if trailing_slash else ""
super(SimpleRouter, self).__init__()
def get_default_basename(self, viewset):
@@ -180,11 +179,13 @@ class SimpleRouter(BaseRouter):
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
- queryset = getattr(viewset, 'queryset', None)
+ queryset = getattr(viewset, "queryset", None)
- assert queryset is not None, '`basename` argument not specified, and could ' \
- 'not automatically determine the name from the viewset, as ' \
- 'it does not have a `.queryset` attribute.'
+ assert queryset is not None, (
+ "`basename` argument not specified, and could "
+ "not automatically determine the name from the viewset, as "
+ "it does not have a `.queryset` attribute."
+ )
return queryset.model._meta.object_name.lower()
@@ -196,18 +197,29 @@ class SimpleRouter(BaseRouter):
"""
# converting to list as iterables are good for one pass, known host needs to be checked again and again for
# different functions.
- known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
+ known_actions = list(
+ flatten(
+ [
+ route.mapping.values()
+ for route in self.routes
+ if isinstance(route, Route)
+ ]
+ )
+ )
extra_actions = viewset.get_extra_actions()
# checking action names against the known actions list
not_allowed = [
- action.__name__ for action in extra_actions
+ action.__name__
+ for action in extra_actions
if action.__name__ in known_actions
]
if not_allowed:
- msg = ('Cannot use the @action decorator on the following '
- 'methods, as they are existing routes: %s')
- raise ImproperlyConfigured(msg % ', '.join(not_allowed))
+ msg = (
+ "Cannot use the @action decorator on the following "
+ "methods, as they are existing routes: %s"
+ )
+ raise ImproperlyConfigured(msg % ", ".join(not_allowed))
# partition detail and list actions
detail_actions = [action for action in extra_actions if action.detail]
@@ -216,9 +228,13 @@ class SimpleRouter(BaseRouter):
routes = []
for route in self.routes:
if isinstance(route, DynamicRoute) and route.detail:
- routes += [self._get_dynamic_route(route, action) for action in detail_actions]
+ routes += [
+ self._get_dynamic_route(route, action) for action in detail_actions
+ ]
elif isinstance(route, DynamicRoute) and not route.detail:
- routes += [self._get_dynamic_route(route, action) for action in list_actions]
+ routes += [
+ self._get_dynamic_route(route, action) for action in list_actions
+ ]
else:
routes.append(route)
@@ -231,9 +247,9 @@ class SimpleRouter(BaseRouter):
url_path = escape_curly_brackets(action.url_path)
return Route(
- url=route.url.replace('{url_path}', url_path),
+ url=route.url.replace("{url_path}", url_path),
mapping=action.mapping,
- name=route.name.replace('{url_name}', action.url_name),
+ name=route.name.replace("{url_name}", action.url_name),
detail=route.detail,
initkwargs=initkwargs,
)
@@ -250,7 +266,7 @@ class SimpleRouter(BaseRouter):
bound_methods[method] = action
return bound_methods
- def get_lookup_regex(self, viewset, lookup_prefix=''):
+ def get_lookup_regex(self, viewset, lookup_prefix=""):
"""
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
@@ -261,16 +277,16 @@ class SimpleRouter(BaseRouter):
https://github.com/alanjds/drf-nested-routers
"""
- base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
+ base_regex = "(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})"
# Use `pk` as default field, unset set. Default regex should not
# consume `.json` style suffixes and should break at '/' boundaries.
- lookup_field = getattr(viewset, 'lookup_field', 'pk')
- lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
- lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
+ lookup_field = getattr(viewset, "lookup_field", "pk")
+ lookup_url_kwarg = getattr(viewset, "lookup_url_kwarg", None) or lookup_field
+ lookup_value = getattr(viewset, "lookup_value_regex", "[^/.]+")
return base_regex.format(
lookup_prefix=lookup_prefix,
lookup_url_kwarg=lookup_url_kwarg,
- lookup_value=lookup_value
+ lookup_value=lookup_value,
)
def get_urls(self):
@@ -292,23 +308,18 @@ class SimpleRouter(BaseRouter):
# Build the url pattern
regex = route.url.format(
- prefix=prefix,
- lookup=lookup,
- trailing_slash=self.trailing_slash
+ prefix=prefix, lookup=lookup, trailing_slash=self.trailing_slash
)
# If there is no prefix, the first part of the url is probably
# controlled by project's urls.py and the router is in an app,
# so a slash in the beginning will (A) cause Django to give
# warnings and (B) generate URLS that will require using '//'.
- if not prefix and regex[:2] == '^/':
- regex = '^' + regex[2:]
+ if not prefix and regex[:2] == "^/":
+ regex = "^" + regex[2:]
initkwargs = route.initkwargs.copy()
- initkwargs.update({
- 'basename': basename,
- 'detail': route.detail,
- })
+ initkwargs.update({"basename": basename, "detail": route.detail})
view = viewset.as_view(mapping, **initkwargs)
name = route.name.format(basename=basename)
@@ -321,6 +332,7 @@ class APIRootView(views.APIView):
"""
The default basic root view for DefaultRouter
"""
+
_ignore_model_permissions = True
schema = None # exclude from schema
api_root_dict = None
@@ -331,14 +343,14 @@ class APIRootView(views.APIView):
namespace = request.resolver_match.namespace
for key, url_name in self.api_root_dict.items():
if namespace:
- url_name = namespace + ':' + url_name
+ url_name = namespace + ":" + url_name
try:
ret[key] = reverse(
url_name,
args=args,
kwargs=kwargs,
request=request,
- format=kwargs.get('format', None)
+ format=kwargs.get("format", None),
)
except NoReverseMatch:
# Don't bail out if eg. no list routes exist, only detail routes.
@@ -352,17 +364,18 @@ class DefaultRouter(SimpleRouter):
The default router extends the SimpleRouter, but also adds in a default
API root view, and adds format suffix patterns to the URLs.
"""
+
include_root_view = True
include_format_suffixes = True
- root_view_name = 'api-root'
+ root_view_name = "api-root"
default_schema_renderers = None
APIRootView = APIRootView
APISchemaView = SchemaView
SchemaGenerator = SchemaGenerator
def __init__(self, *args, **kwargs):
- if 'root_renderers' in kwargs:
- self.root_renderers = kwargs.pop('root_renderers')
+ if "root_renderers" in kwargs:
+ self.root_renderers = kwargs.pop("root_renderers")
else:
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
super(DefaultRouter, self).__init__(*args, **kwargs)
@@ -387,7 +400,7 @@ class DefaultRouter(SimpleRouter):
if self.include_root_view:
view = self.get_api_root_view(api_urls=urls)
- root_url = url(r'^$', view, name=self.root_view_name)
+ root_url = url(r"^$", view, name=self.root_view_name)
urls.append(root_url)
if self.include_format_suffixes:
diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py
index ba0ec6536..00acae91c 100644
--- a/rest_framework/schemas/__init__.py
+++ b/rest_framework/schemas/__init__.py
@@ -27,18 +27,29 @@ from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa
def get_schema_view(
- title=None, url=None, description=None, urlconf=None, renderer_classes=None,
- public=False, patterns=None, generator_class=SchemaGenerator,
- authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
- permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
+ title=None,
+ url=None,
+ description=None,
+ urlconf=None,
+ renderer_classes=None,
+ public=False,
+ patterns=None,
+ generator_class=SchemaGenerator,
+ authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
+ permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
+):
"""
Return a schema view.
"""
# Avoid import cycle on APIView
from .views import SchemaView
+
generator = generator_class(
- title=title, url=url, description=description,
- urlconf=urlconf, patterns=patterns,
+ title=title,
+ url=url,
+ description=description,
+ urlconf=urlconf,
+ patterns=patterns,
)
return SchemaView.as_view(
renderer_classes=renderer_classes,
diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py
index db226a6c1..e1b05273b 100644
--- a/rest_framework/schemas/generators.py
+++ b/rest_framework/schemas/generators.py
@@ -15,7 +15,11 @@ from django.utils import six
from rest_framework import exceptions
from rest_framework.compat import (
- URLPattern, URLResolver, coreapi, coreschema, get_original_route
+ URLPattern,
+ URLResolver,
+ coreapi,
+ coreschema,
+ get_original_route,
)
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
@@ -25,7 +29,7 @@ from .utils import is_list_view
def common_path(paths):
- split_paths = [path.strip('/').split('/') for path in paths]
+ split_paths = [path.strip("/").split("/") for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
@@ -33,7 +37,7 @@ def common_path(paths):
if c != s2[i]:
common = s1[:i]
break
- return '/' + '/'.join(common)
+ return "/" + "/".join(common)
def get_pk_name(model):
@@ -47,7 +51,8 @@ def is_api_view(callback):
"""
# Avoid import cycle on APIView
from rest_framework.views import APIView
- cls = getattr(callback, 'cls', None)
+
+ cls = getattr(callback, "cls", None)
return (cls is not None) and issubclass(cls, APIView)
@@ -78,7 +83,7 @@ class LinkNode(OrderedDict):
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
- key = '{}_{}'.format(preferred_key, current_val)
+ key = "{}_{}".format(preferred_key, current_val)
if key not in self:
return key
@@ -101,9 +106,7 @@ def insert_into(target, keys, value):
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
- value_url=value.url,
- target_url=target.url,
- keys=keys
+ value_url=value.url, target_url=target.url, keys=keys
)
raise ValueError(msg)
@@ -119,24 +122,25 @@ def distribute_links(obj):
def is_custom_action(action):
return action not in {
- 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
+ "retrieve",
+ "list",
+ "create",
+ "update",
+ "partial_update",
+ "destroy",
}
def endpoint_ordering(endpoint):
path, method, callback = endpoint
- method_priority = {
- 'GET': 0,
- 'POST': 1,
- 'PUT': 2,
- 'PATCH': 3,
- 'DELETE': 4
- }.get(method, 5)
+ method_priority = {"GET": 0, "POST": 1, "PUT": 2, "PATCH": 3, "DELETE": 4}.get(
+ method, 5
+ )
return (path, method_priority)
_PATH_PARAMETER_COMPONENT_RE = re.compile(
- r'<(?:(?P[^>:]+):)?(?P\w+)>'
+ r"<(?:(?P[^>:]+):)?(?P\w+)>"
)
@@ -144,6 +148,7 @@ class EndpointEnumerator(object):
"""
A class to determine the available API endpoints that a project exposes.
"""
+
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
@@ -159,7 +164,7 @@ class EndpointEnumerator(object):
self.patterns = patterns
- def get_api_endpoints(self, patterns=None, prefix=''):
+ def get_api_endpoints(self, patterns=None, prefix=""):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
@@ -180,8 +185,7 @@ class EndpointEnumerator(object):
elif isinstance(pattern, URLResolver):
nested_endpoints = self.get_api_endpoints(
- patterns=pattern.url_patterns,
- prefix=path_regex
+ patterns=pattern.url_patterns, prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
@@ -196,7 +200,7 @@ class EndpointEnumerator(object):
path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
- path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g}', path)
+ path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r"{\g}", path)
return path
def should_include_endpoint(self, path, callback):
@@ -209,11 +213,11 @@ class EndpointEnumerator(object):
if callback.cls.schema is None:
return False
- if 'schema' in callback.initkwargs:
- if callback.initkwargs['schema'] is None:
+ if "schema" in callback.initkwargs:
+ if callback.initkwargs["schema"] is None:
return False
- if path.endswith('.{format}') or path.endswith('.{format}/'):
+ if path.endswith(".{format}") or path.endswith(".{format}/"):
return False # Ignore .json style URLs.
return True
@@ -222,24 +226,24 @@ class EndpointEnumerator(object):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
- if hasattr(callback, 'actions'):
+ if hasattr(callback, "actions"):
actions = set(callback.actions)
http_method_names = set(callback.cls.http_method_names)
methods = [method.upper() for method in actions & http_method_names]
else:
methods = callback.cls().allowed_methods
- return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
+ return [method for method in methods if method not in ("OPTIONS", "HEAD")]
class SchemaGenerator(object):
# Map HTTP methods onto actions.
default_mapping = {
- 'get': 'retrieve',
- 'post': 'create',
- 'put': 'update',
- 'patch': 'partial_update',
- 'delete': 'destroy',
+ "get": "retrieve",
+ "post": "create",
+ "put": "update",
+ "patch": "partial_update",
+ "delete": "destroy",
}
endpoint_inspector_cls = EndpointEnumerator
@@ -253,12 +257,14 @@ class SchemaGenerator(object):
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
- def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
- assert coreapi, '`coreapi` must be installed for schema support.'
- assert coreschema, '`coreschema` must be installed for schema support.'
+ def __init__(
+ self, title=None, url=None, description=None, patterns=None, urlconf=None
+ ):
+ assert coreapi, "`coreapi` must be installed for schema support."
+ assert coreschema, "`coreschema` must be installed for schema support."
- if url and not url.endswith('/'):
- url += '/'
+ if url and not url.endswith("/"):
+ url += "/"
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
@@ -288,8 +294,7 @@ class SchemaGenerator(object):
distribute_links(links)
return coreapi.Document(
- title=self.title, description=self.description,
- url=url, content=links
+ title=self.title, description=self.description, url=url, content=links
)
def get_links(self, request=None):
@@ -317,7 +322,7 @@ class SchemaGenerator(object):
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
- subpath = path[len(prefix):]
+ subpath = path[len(prefix) :]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
@@ -342,35 +347,35 @@ class SchemaGenerator(object):
"""
prefixes = []
for path in paths:
- components = path.strip('/').split('/')
+ components = path.strip("/").split("/")
initial_components = []
for component in components:
- if '{' in component:
+ if "{" in component:
break
initial_components.append(component)
- prefix = '/'.join(initial_components[:-1])
+ prefix = "/".join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
- return '/'
- prefixes.append('/' + prefix + '/')
+ return "/"
+ prefixes.append("/" + prefix + "/")
return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
- view = callback.cls(**getattr(callback, 'initkwargs', {}))
+ view = callback.cls(**getattr(callback, "initkwargs", {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
- view.action_map = getattr(callback, 'actions', None)
+ view.action_map = getattr(callback, "actions", None)
- actions = getattr(callback, 'actions', None)
+ actions = getattr(callback, "actions", None)
if actions is not None:
- if method == 'OPTIONS':
- view.action = 'metadata'
+ if method == "OPTIONS":
+ view.action = "metadata"
else:
view.action = actions.get(method.lower())
@@ -398,14 +403,14 @@ class SchemaGenerator(object):
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
- if not self.coerce_path_pk or '{pk}' not in path:
+ if not self.coerce_path_pk or "{pk}" not in path:
return path
- model = getattr(getattr(view, 'queryset', None), 'model', None)
+ model = getattr(getattr(view, "queryset", None), "model", None)
if model:
field_name = get_pk_name(model)
else:
- field_name = 'id'
- return path.replace('{pk}', '{%s}' % field_name)
+ field_name = "id"
+ return path.replace("{pk}", "{%s}" % field_name)
# Method for generating the link layout....
@@ -421,20 +426,20 @@ class SchemaGenerator(object):
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
- if hasattr(view, 'action'):
+ if hasattr(view, "action"):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
- action = 'list'
+ action = "list"
else:
action = self.default_mapping[method.lower()]
named_path_components = [
- component for component
- in subpath.strip('/').split('/')
- if '{' not in component
+ component
+ for component in subpath.strip("/").split("/")
+ if "{" not in component
]
if is_custom_action(action):
diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py
index 85142edce..82e75d945 100644
--- a/rest_framework/schemas/inspectors.py
+++ b/rest_framework/schemas/inspectors.py
@@ -21,46 +21,38 @@ from rest_framework.utils import formatting
from .utils import is_list_view
-header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
+
+header_regex = re.compile("^[a-zA-Z][0-9A-Za-z_]*:")
def field_to_schema(field):
- title = force_text(field.label) if field.label else ''
- description = force_text(field.help_text) if field.help_text else ''
+ title = force_text(field.label) if field.label else ""
+ description = force_text(field.help_text) if field.help_text else ""
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
- items=child_schema,
- title=title,
- description=description
+ items=child_schema, title=title, description=description
)
elif isinstance(field, serializers.DictField):
- return coreschema.Object(
- title=title,
- description=description
- )
+ return coreschema.Object(title=title, description=description)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
- properties=OrderedDict([
- (key, field_to_schema(value))
- for key, value
- in field.fields.items()
- ]),
+ properties=OrderedDict(
+ [(key, field_to_schema(value)) for key, value in field.fields.items()]
+ ),
title=title,
- description=description
+ description=description,
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
- items=related_field_schema,
- title=title,
- description=description
+ items=related_field_schema, title=title, description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
- model = getattr(field.queryset, 'model', None)
+ model = getattr(field.queryset, "model", None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
@@ -72,13 +64,11 @@ def field_to_schema(field):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
- description=description
+ description=description,
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
- enum=list(field.choices),
- title=title,
- description=description
+ enum=list(field.choices), title=title, description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
@@ -87,25 +77,17 @@ def field_to_schema(field):
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
- return coreschema.String(
- title=title,
- description=description,
- format='date'
- )
+ return coreschema.String(title=title, description=description, format="date")
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
- title=title,
- description=description,
- format='date-time'
+ title=title, description=description, format="date-time"
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
- if field.style.get('base_template') == 'textarea.html':
+ if field.style.get("base_template") == "textarea.html":
return coreschema.String(
- title=title,
- description=description,
- format='textarea'
+ title=title, description=description, format="textarea"
)
return coreschema.String(title=title, description=description)
@@ -113,15 +95,14 @@ def field_to_schema(field):
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
- value_type = _('unique integer value')
+ value_type = _("unique integer value")
elif isinstance(model_field, models.UUIDField):
- value_type = _('UUID string')
+ value_type = _("UUID string")
else:
- value_type = _('unique value')
+ value_type = _("unique value")
- return _('A {value_type} identifying this {name}.').format(
- value_type=value_type,
- name=model._meta.verbose_name,
+ return _("A {value_type} identifying this {name}.").format(
+ value_type=value_type, name=model._meta.verbose_name
)
@@ -200,6 +181,7 @@ class AutoSchema(ViewInspector):
Responsible for per-view introspection and schema generation.
"""
+
def __init__(self, manual_fields=None):
"""
Parameters:
@@ -221,14 +203,14 @@ class AutoSchema(ViewInspector):
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
- if fields and any([field.location in ('form', 'body') for field in fields]):
+ if fields and any([field.location in ("form", "body") for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
- if base_url and path.startswith('/'):
+ if base_url and path.startswith("/"):
path = path[1:]
return coreapi.Link(
@@ -236,7 +218,7 @@ class AutoSchema(ViewInspector):
action=method.lower(),
encoding=encoding,
fields=fields,
- description=description
+ description=description,
)
def get_description(self, path, method):
@@ -248,25 +230,31 @@ class AutoSchema(ViewInspector):
"""
view = self.view
- method_name = getattr(view, 'action', method.lower())
+ method_name = getattr(view, "action", method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
- return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
+ return self._get_description_section(
+ view, method.lower(), formatting.dedent(smart_text(method_docstring))
+ )
else:
- return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
+ return self._get_description_section(
+ view,
+ getattr(view, "action", method.lower()),
+ view.get_view_description(),
+ )
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
- current_section = ''
- sections = {'': ''}
+ current_section = ""
+ sections = {"": ""}
for line in lines:
if header_regex.match(line):
- current_section, seperator, lead = line.partition(':')
+ current_section, seperator, lead = line.partition(":")
sections[current_section] = lead.strip()
else:
- sections[current_section] += '\n' + line
+ sections[current_section] += "\n" + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
@@ -275,7 +263,7 @@ class AutoSchema(ViewInspector):
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
- return sections[''].strip()
+ return sections[""].strip()
def get_path_fields(self, path, method):
"""
@@ -283,12 +271,12 @@ class AutoSchema(ViewInspector):
templated path variables.
"""
view = self.view
- model = getattr(getattr(view, 'queryset', None), 'model', None)
+ model = getattr(getattr(view, "queryset", None), "model", None)
fields = []
for variable in uritemplate.variables(path):
- title = ''
- description = ''
+ title = ""
+ description = ""
schema_cls = coreschema.String
kwargs = {}
if model is not None:
@@ -306,16 +294,19 @@ class AutoSchema(ViewInspector):
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
- if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
- kwargs['pattern'] = view.lookup_value_regex
+ if (
+ hasattr(view, "lookup_value_regex")
+ and view.lookup_field == variable
+ ):
+ kwargs["pattern"] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
- location='path',
+ location="path",
required=True,
- schema=schema_cls(title=title, description=description, **kwargs)
+ schema=schema_cls(title=title, description=description, **kwargs),
)
fields.append(field)
@@ -328,28 +319,29 @@ class AutoSchema(ViewInspector):
"""
view = self.view
- if method not in ('PUT', 'PATCH', 'POST'):
+ if method not in ("PUT", "PATCH", "POST"):
return []
- if not hasattr(view, 'get_serializer'):
+ if not hasattr(view, "get_serializer"):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
- warnings.warn('{}.get_serializer() raised an exception during '
- 'schema generation. Serializer fields will not be '
- 'generated for {} {}.'
- .format(view.__class__.__name__, method, path))
+ warnings.warn(
+ "{}.get_serializer() raised an exception during "
+ "schema generation. Serializer fields will not be "
+ "generated for {} {}.".format(view.__class__.__name__, method, path)
+ )
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
- name='data',
- location='body',
+ name="data",
+ location="body",
required=True,
- schema=coreschema.Array()
+ schema=coreschema.Array(),
)
]
@@ -361,12 +353,12 @@ class AutoSchema(ViewInspector):
if field.read_only or isinstance(field, serializers.HiddenField):
continue
- required = field.required and method != 'PATCH'
+ required = field.required and method != "PATCH"
field = coreapi.Field(
name=field.field_name,
- location='form',
+ location="form",
required=required,
- schema=field_to_schema(field)
+ schema=field_to_schema(field),
)
fields.append(field)
@@ -378,7 +370,7 @@ class AutoSchema(ViewInspector):
if not is_list_view(path, method, view):
return []
- pagination = getattr(view, 'pagination_class', None)
+ pagination = getattr(view, "pagination_class", None)
if not pagination:
return []
@@ -397,11 +389,17 @@ class AutoSchema(ViewInspector):
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
- if getattr(self.view, 'filter_backends', None) is None:
+ if getattr(self.view, "filter_backends", None) is None:
return False
- if hasattr(self.view, 'action'):
- return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
+ if hasattr(self.view, "action"):
+ return self.view.action in [
+ "list",
+ "retrieve",
+ "update",
+ "partial_update",
+ "destroy",
+ ]
return method.lower() in ["get", "put", "patch", "delete"]
@@ -447,18 +445,18 @@ class AutoSchema(ViewInspector):
# Core API supports the following request encodings over HTTP...
supported_media_types = {
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data',
+ "application/json",
+ "application/x-www-form-urlencoded",
+ "multipart/form-data",
}
- parser_classes = getattr(view, 'parser_classes', [])
+ parser_classes = getattr(view, "parser_classes", [])
for parser_class in parser_classes:
- media_type = getattr(parser_class, 'media_type', None)
+ media_type = getattr(parser_class, "media_type", None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
- if media_type == '*/*':
- return 'application/octet-stream'
+ if media_type == "*/*":
+ return "application/octet-stream"
return None
@@ -468,7 +466,8 @@ class ManualSchema(ViewInspector):
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
- def __init__(self, fields, description='', encoding=None):
+
+ def __init__(self, fields, description="", encoding=None):
"""
Parameters:
@@ -476,14 +475,16 @@ class ManualSchema(ViewInspector):
* `description`: String description for view. Optional.
"""
super(ManualSchema, self).__init__()
- assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
+ assert all(
+ isinstance(f, coreapi.Field) for f in fields
+ ), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
- if base_url and path.startswith('/'):
+ if base_url and path.startswith("/"):
path = path[1:]
return coreapi.Link(
@@ -491,21 +492,22 @@ class ManualSchema(ViewInspector):
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
- description=self._description
+ description=self._description,
)
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
+
def __get__(self, instance, owner):
result = super(DefaultSchema, self).__get__(instance, owner)
if not isinstance(result, DefaultSchema):
return result
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
- assert issubclass(inspector_class, ViewInspector), (
- "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
- )
+ assert issubclass(
+ inspector_class, ViewInspector
+ ), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
inspector = inspector_class()
inspector.view = instance
return inspector
diff --git a/rest_framework/schemas/utils.py b/rest_framework/schemas/utils.py
index 76437a20a..3eacc686f 100644
--- a/rest_framework/schemas/utils.py
+++ b/rest_framework/schemas/utils.py
@@ -10,15 +10,15 @@ def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
- if hasattr(view, 'action'):
+ if hasattr(view, "action"):
# Viewsets have an explicitly defined action, which we can inspect.
- return view.action == 'list'
+ return view.action == "list"
- if method.lower() != 'get':
+ if method.lower() != "get":
return False
if isinstance(view, RetrieveModelMixin):
return False
- path_components = path.strip('/').split('/')
- if path_components and '{' in path_components[-1]:
+ path_components = path.strip("/").split("/")
+ if path_components and "{" in path_components[-1]:
return False
return True
diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py
index f5e327a94..e609f733c 100644
--- a/rest_framework/schemas/views.py
+++ b/rest_framework/schemas/views.py
@@ -21,7 +21,7 @@ class SchemaView(APIView):
if self.renderer_classes is None:
self.renderer_classes = [
renderers.OpenAPIRenderer,
- renderers.CoreJSONRenderer
+ renderers.CoreJSONRenderer,
]
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes += [renderers.BrowsableAPIRenderer]
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index 9830edb3f..2cd8cccc0 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -17,12 +17,13 @@ import inspect
import traceback
from collections import OrderedDict
-from django.core.exceptions import ImproperlyConfigured
-from django.core.exceptions import ValidationError as DjangoValidationError
+from django.core.exceptions import (
+ ImproperlyConfigured,
+ ValidationError as DjangoValidationError,
+)
from django.db import models
from django.db.models import DurationField as ModelDurationField
-from django.db.models.fields import Field as DjangoModelField
-from django.db.models.fields import FieldDoesNotExist
+from django.db.models.fields import Field as DjangoModelField, FieldDoesNotExist
from django.utils import six, timezone
from django.utils.functional import cached_property
from django.utils.translation import ugettext_lazy as _
@@ -33,18 +34,28 @@ from rest_framework.fields import get_error_detail, set_value
from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation
from rest_framework.utils.field_mapping import (
- ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
- get_relation_kwargs, get_url_kwargs
+ ClassLookupDict,
+ get_field_kwargs,
+ get_nested_relation_kwargs,
+ get_relation_kwargs,
+ get_url_kwargs,
)
from rest_framework.utils.serializer_helpers import (
- BindingDict, BoundField, JSONBoundField, NestedBoundField, ReturnDict,
- ReturnList
+ BindingDict,
+ BoundField,
+ JSONBoundField,
+ NestedBoundField,
+ ReturnDict,
+ ReturnList,
)
from rest_framework.validators import (
- UniqueForDateValidator, UniqueForMonthValidator, UniqueForYearValidator,
- UniqueTogetherValidator
+ UniqueForDateValidator,
+ UniqueForMonthValidator,
+ UniqueForYearValidator,
+ UniqueTogetherValidator,
)
+
# Note: We do the following so that users of the framework can use this style:
#
# example_field = serializers.CharField(...)
@@ -52,37 +63,84 @@ from rest_framework.validators import (
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
from rest_framework.fields import ( # NOQA # isort:skip
- BooleanField, CharField, ChoiceField, DateField, DateTimeField, DecimalField,
- DictField, DurationField, EmailField, Field, FileField, FilePathField, FloatField,
- HiddenField, HStoreField, IPAddressField, ImageField, IntegerField, JSONField,
- ListField, ModelField, MultipleChoiceField, NullBooleanField, ReadOnlyField,
- RegexField, SerializerMethodField, SlugField, TimeField, URLField, UUIDField,
+ BooleanField,
+ CharField,
+ ChoiceField,
+ DateField,
+ DateTimeField,
+ DecimalField,
+ DictField,
+ DurationField,
+ EmailField,
+ Field,
+ FileField,
+ FilePathField,
+ FloatField,
+ HiddenField,
+ HStoreField,
+ IPAddressField,
+ ImageField,
+ IntegerField,
+ JSONField,
+ ListField,
+ ModelField,
+ MultipleChoiceField,
+ NullBooleanField,
+ ReadOnlyField,
+ RegexField,
+ SerializerMethodField,
+ SlugField,
+ TimeField,
+ URLField,
+ UUIDField,
)
from rest_framework.relations import ( # NOQA # isort:skip
- HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField,
- PrimaryKeyRelatedField, RelatedField, SlugRelatedField, StringRelatedField,
+ HyperlinkedIdentityField,
+ HyperlinkedRelatedField,
+ ManyRelatedField,
+ PrimaryKeyRelatedField,
+ RelatedField,
+ SlugRelatedField,
+ StringRelatedField,
)
# Non-field imports, but public API
from rest_framework.fields import ( # NOQA # isort:skip
- CreateOnlyDefault, CurrentUserDefault, SkipField, empty
+ CreateOnlyDefault,
+ CurrentUserDefault,
+ SkipField,
+ empty,
)
from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
LIST_SERIALIZER_KWARGS = (
- 'read_only', 'write_only', 'required', 'default', 'initial', 'source',
- 'label', 'help_text', 'style', 'error_messages', 'allow_empty',
- 'instance', 'data', 'partial', 'context', 'allow_null'
+ "read_only",
+ "write_only",
+ "required",
+ "default",
+ "initial",
+ "source",
+ "label",
+ "help_text",
+ "style",
+ "error_messages",
+ "allow_empty",
+ "instance",
+ "data",
+ "partial",
+ "context",
+ "allow_null",
)
-ALL_FIELDS = '__all__'
+ALL_FIELDS = "__all__"
# BaseSerializer
# --------------
+
class BaseSerializer(Field):
"""
The BaseSerializer class provides a minimal class which may be used
@@ -112,15 +170,15 @@ class BaseSerializer(Field):
self.instance = instance
if data is not empty:
self.initial_data = data
- self.partial = kwargs.pop('partial', False)
- self._context = kwargs.pop('context', {})
- kwargs.pop('many', None)
+ self.partial = kwargs.pop("partial", False)
+ self._context = kwargs.pop("context", {})
+ kwargs.pop("many", None)
super(BaseSerializer, self).__init__(**kwargs)
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
# `ListSerializer` classes instead when `many=True` is set.
- if kwargs.pop('many', False):
+ if kwargs.pop("many", False):
return cls.many_init(*args, **kwargs)
return super(BaseSerializer, cls).__new__(cls, *args, **kwargs)
@@ -141,51 +199,52 @@ class BaseSerializer(Field):
kwargs['child'] = cls()
return CustomListSerializer(*args, **kwargs)
"""
- allow_empty = kwargs.pop('allow_empty', None)
+ allow_empty = kwargs.pop("allow_empty", None)
child_serializer = cls(*args, **kwargs)
- list_kwargs = {
- 'child': child_serializer,
- }
+ list_kwargs = {"child": child_serializer}
if allow_empty is not None:
- list_kwargs['allow_empty'] = allow_empty
- list_kwargs.update({
- key: value for key, value in kwargs.items()
- if key in LIST_SERIALIZER_KWARGS
- })
- meta = getattr(cls, 'Meta', None)
- list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer)
+ list_kwargs["allow_empty"] = allow_empty
+ list_kwargs.update(
+ {
+ key: value
+ for key, value in kwargs.items()
+ if key in LIST_SERIALIZER_KWARGS
+ }
+ )
+ meta = getattr(cls, "Meta", None)
+ list_serializer_class = getattr(meta, "list_serializer_class", ListSerializer)
return list_serializer_class(*args, **list_kwargs)
def to_internal_value(self, data):
- raise NotImplementedError('`to_internal_value()` must be implemented.')
+ raise NotImplementedError("`to_internal_value()` must be implemented.")
def to_representation(self, instance):
- raise NotImplementedError('`to_representation()` must be implemented.')
+ raise NotImplementedError("`to_representation()` must be implemented.")
def update(self, instance, validated_data):
- raise NotImplementedError('`update()` must be implemented.')
+ raise NotImplementedError("`update()` must be implemented.")
def create(self, validated_data):
- raise NotImplementedError('`create()` must be implemented.')
+ raise NotImplementedError("`create()` must be implemented.")
def save(self, **kwargs):
- assert not hasattr(self, 'save_object'), (
- 'Serializer `%s.%s` has old-style version 2 `.save_object()` '
- 'that is no longer compatible with REST framework 3. '
- 'Use the new-style `.create()` and `.update()` methods instead.' %
- (self.__class__.__module__, self.__class__.__name__)
+ assert not hasattr(self, "save_object"), (
+ "Serializer `%s.%s` has old-style version 2 `.save_object()` "
+ "that is no longer compatible with REST framework 3. "
+ "Use the new-style `.create()` and `.update()` methods instead."
+ % (self.__class__.__module__, self.__class__.__name__)
)
- assert hasattr(self, '_errors'), (
- 'You must call `.is_valid()` before calling `.save()`.'
- )
+ assert hasattr(
+ self, "_errors"
+ ), "You must call `.is_valid()` before calling `.save()`."
- assert not self.errors, (
- 'You cannot call `.save()` on a serializer with invalid data.'
- )
+ assert (
+ not self.errors
+ ), "You cannot call `.save()` on a serializer with invalid data."
# Guard against incorrect use of `serializer.save(commit=False)`
- assert 'commit' not in kwargs, (
+ assert "commit" not in kwargs, (
"'commit' is not a valid keyword argument to the 'save()' method. "
"If you need to access data before committing to the database then "
"inspect 'serializer.validated_data' instead. "
@@ -194,44 +253,41 @@ class BaseSerializer(Field):
"For example: 'serializer.save(owner=request.user)'.'"
)
- assert not hasattr(self, '_data'), (
+ assert not hasattr(self, "_data"), (
"You cannot call `.save()` after accessing `serializer.data`."
"If you need to access data before committing to the database then "
"inspect 'serializer.validated_data' instead. "
)
- validated_data = dict(
- list(self.validated_data.items()) +
- list(kwargs.items())
- )
+ validated_data = dict(list(self.validated_data.items()) + list(kwargs.items()))
if self.instance is not None:
self.instance = self.update(self.instance, validated_data)
- assert self.instance is not None, (
- '`update()` did not return an object instance.'
- )
+ assert (
+ self.instance is not None
+ ), "`update()` did not return an object instance."
else:
self.instance = self.create(validated_data)
- assert self.instance is not None, (
- '`create()` did not return an object instance.'
- )
+ assert (
+ self.instance is not None
+ ), "`create()` did not return an object instance."
return self.instance
def is_valid(self, raise_exception=False):
- assert not hasattr(self, 'restore_object'), (
- 'Serializer `%s.%s` has old-style version 2 `.restore_object()` '
- 'that is no longer compatible with REST framework 3. '
- 'Use the new-style `.create()` and `.update()` methods instead.' %
- (self.__class__.__module__, self.__class__.__name__)
+ assert not hasattr(self, "restore_object"), (
+ "Serializer `%s.%s` has old-style version 2 `.restore_object()` "
+ "that is no longer compatible with REST framework 3. "
+ "Use the new-style `.create()` and `.update()` methods instead."
+ % (self.__class__.__module__, self.__class__.__name__)
)
- assert hasattr(self, 'initial_data'), (
- 'Cannot call `.is_valid()` as no `data=` keyword argument was '
- 'passed when instantiating the serializer instance.'
+ assert hasattr(self, "initial_data"), (
+ "Cannot call `.is_valid()` as no `data=` keyword argument was "
+ "passed when instantiating the serializer instance."
)
- if not hasattr(self, '_validated_data'):
+ if not hasattr(self, "_validated_data"):
try:
self._validated_data = self.run_validation(self.initial_data)
except ValidationError as exc:
@@ -247,20 +303,22 @@ class BaseSerializer(Field):
@property
def data(self):
- if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'):
+ if hasattr(self, "initial_data") and not hasattr(self, "_validated_data"):
msg = (
- 'When a serializer is passed a `data` keyword argument you '
- 'must call `.is_valid()` before attempting to access the '
- 'serialized `.data` representation.\n'
- 'You should either call `.is_valid()` first, '
- 'or access `.initial_data` instead.'
+ "When a serializer is passed a `data` keyword argument you "
+ "must call `.is_valid()` before attempting to access the "
+ "serialized `.data` representation.\n"
+ "You should either call `.is_valid()` first, "
+ "or access `.initial_data` instead."
)
raise AssertionError(msg)
- if not hasattr(self, '_data'):
- if self.instance is not None and not getattr(self, '_errors', None):
+ if not hasattr(self, "_data"):
+ if self.instance is not None and not getattr(self, "_errors", None):
self._data = self.to_representation(self.instance)
- elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None):
+ elif hasattr(self, "_validated_data") and not getattr(
+ self, "_errors", None
+ ):
self._data = self.to_representation(self.validated_data)
else:
self._data = self.get_initial()
@@ -268,15 +326,15 @@ class BaseSerializer(Field):
@property
def errors(self):
- if not hasattr(self, '_errors'):
- msg = 'You must call `.is_valid()` before accessing `.errors`.'
+ if not hasattr(self, "_errors"):
+ msg = "You must call `.is_valid()` before accessing `.errors`."
raise AssertionError(msg)
return self._errors
@property
def validated_data(self):
- if not hasattr(self, '_validated_data'):
- msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
+ if not hasattr(self, "_validated_data"):
+ msg = "You must call `.is_valid()` before accessing `.validated_data`."
raise AssertionError(msg)
return self._validated_data
@@ -284,6 +342,7 @@ class BaseSerializer(Field):
# Serializer & ListSerializer classes
# -----------------------------------
+
class SerializerMetaclass(type):
"""
This metaclass sets a dictionary named `_declared_fields` on the class.
@@ -295,26 +354,28 @@ class SerializerMetaclass(type):
@classmethod
def _get_declared_fields(cls, bases, attrs):
- fields = [(field_name, attrs.pop(field_name))
- for field_name, obj in list(attrs.items())
- if isinstance(obj, Field)]
+ fields = [
+ (field_name, attrs.pop(field_name))
+ for field_name, obj in list(attrs.items())
+ if isinstance(obj, Field)
+ ]
fields.sort(key=lambda x: x[1]._creation_counter)
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
for base in reversed(bases):
- if hasattr(base, '_declared_fields'):
+ if hasattr(base, "_declared_fields"):
fields = [
- (field_name, obj) for field_name, obj
- in base._declared_fields.items()
+ (field_name, obj)
+ for field_name, obj in base._declared_fields.items()
if field_name not in attrs
] + fields
return OrderedDict(fields)
def __new__(cls, name, bases, attrs):
- attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
+ attrs["_declared_fields"] = cls._get_declared_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
@@ -335,19 +396,15 @@ def as_serializer_error(exc):
}
elif isinstance(detail, list):
# Errors raised as a list are non-field errors.
- return {
- api_settings.NON_FIELD_ERRORS_KEY: detail
- }
+ return {api_settings.NON_FIELD_ERRORS_KEY: detail}
# Errors raised as a string are non-field errors.
- return {
- api_settings.NON_FIELD_ERRORS_KEY: [detail]
- }
+ return {api_settings.NON_FIELD_ERRORS_KEY: [detail]}
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
default_error_messages = {
- 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
+ "invalid": _("Invalid data. Expected a dictionary, but got {datatype}.")
}
@property
@@ -358,7 +415,7 @@ class Serializer(BaseSerializer):
# `fields` is evaluated lazily. We do this to ensure that we don't
# have issues importing modules that use ModelSerializers as fields,
# even if Django's app-loading stage has not yet run.
- if not hasattr(self, '_fields'):
+ if not hasattr(self, "_fields"):
self._fields = BindingDict(self)
for key, value in self.get_fields().items():
self._fields[key] = value
@@ -366,16 +423,11 @@ class Serializer(BaseSerializer):
@cached_property
def _writable_fields(self):
- return [
- field for field in self.fields.values() if not field.read_only
- ]
+ return [field for field in self.fields.values() if not field.read_only]
@cached_property
def _readable_fields(self):
- return [
- field for field in self.fields.values()
- if not field.write_only
- ]
+ return [field for field in self.fields.values() if not field.write_only]
def get_fields(self):
"""
@@ -391,28 +443,32 @@ class Serializer(BaseSerializer):
Returns a list of validator callables.
"""
# Used by the lazily-evaluated `validators` property.
- meta = getattr(self, 'Meta', None)
- validators = getattr(meta, 'validators', None)
+ meta = getattr(self, "Meta", None)
+ validators = getattr(meta, "validators", None)
return list(validators) if validators else []
def get_initial(self):
- if hasattr(self, 'initial_data'):
+ if hasattr(self, "initial_data"):
# initial_data may not be a valid type
if not isinstance(self.initial_data, Mapping):
return OrderedDict()
- return OrderedDict([
- (field_name, field.get_value(self.initial_data))
- for field_name, field in self.fields.items()
- if (field.get_value(self.initial_data) is not empty) and
- not field.read_only
- ])
+ return OrderedDict(
+ [
+ (field_name, field.get_value(self.initial_data))
+ for field_name, field in self.fields.items()
+ if (field.get_value(self.initial_data) is not empty)
+ and not field.read_only
+ ]
+ )
- return OrderedDict([
- (field.field_name, field.get_initial())
- for field in self.fields.values()
- if not field.read_only
- ])
+ return OrderedDict(
+ [
+ (field.field_name, field.get_initial())
+ for field in self.fields.values()
+ if not field.read_only
+ ]
+ )
def get_value(self, dictionary):
# We override the default field access in order to support
@@ -435,7 +491,7 @@ class Serializer(BaseSerializer):
try:
self.run_validators(value)
value = self.validate(value)
- assert value is not None, '.validate() should return the validated data'
+ assert value is not None, ".validate() should return the validated data"
except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc))
@@ -443,8 +499,12 @@ class Serializer(BaseSerializer):
def _read_only_defaults(self):
fields = [
- field for field in self.fields.values()
- if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
+ field
+ for field in self.fields.values()
+ if (field.read_only)
+ and (field.default != empty)
+ and (field.source != "*")
+ and ("." not in field.source)
]
defaults = OrderedDict()
@@ -473,19 +533,19 @@ class Serializer(BaseSerializer):
Dict of native values <- Dict of primitive datatypes.
"""
if not isinstance(data, Mapping):
- message = self.error_messages['invalid'].format(
+ message = self.error_messages["invalid"].format(
datatype=type(data).__name__
)
- raise ValidationError({
- api_settings.NON_FIELD_ERRORS_KEY: [message]
- }, code='invalid')
+ raise ValidationError(
+ {api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="invalid"
+ )
ret = OrderedDict()
errors = OrderedDict()
fields = self._writable_fields
for field in fields:
- validate_method = getattr(self, 'validate_' + field.field_name, None)
+ validate_method = getattr(self, "validate_" + field.field_name, None)
primitive_value = field.get_value(data)
try:
validated_value = field.run_validation(primitive_value)
@@ -523,7 +583,9 @@ class Serializer(BaseSerializer):
#
# For related fields with `use_pk_only_optimization` we need to
# resolve the pk value.
- check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
+ check_for_none = (
+ attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
+ )
if check_for_none is None:
ret[field.field_name] = None
else:
@@ -548,7 +610,7 @@ class Serializer(BaseSerializer):
def __getitem__(self, key):
field = self.fields[key]
value = self.data.get(key)
- error = self.errors.get(key) if hasattr(self, '_errors') else None
+ error = self.errors.get(key) if hasattr(self, "_errors") else None
if isinstance(field, Serializer):
return NestedBoundField(field, value, error)
if isinstance(field, JSONField):
@@ -566,10 +628,14 @@ class Serializer(BaseSerializer):
@property
def errors(self):
ret = super(Serializer, self).errors
- if isinstance(ret, list) and len(ret) == 1 and getattr(ret[0], 'code', None) == 'null':
+ if (
+ isinstance(ret, list)
+ and len(ret) == 1
+ and getattr(ret[0], "code", None) == "null"
+ ):
# Edge case. Provide a more descriptive error than
# "this field may not be null", when no data is passed.
- detail = ErrorDetail('No data provided', code='null')
+ detail = ErrorDetail("No data provided", code="null")
ret = {api_settings.NON_FIELD_ERRORS_KEY: [detail]}
return ReturnDict(ret, serializer=self)
@@ -577,29 +643,30 @@ class Serializer(BaseSerializer):
# There's some replication of `ListField` here,
# but that's probably better than obfuscating the call hierarchy.
+
class ListSerializer(BaseSerializer):
child = None
many = True
default_error_messages = {
- 'not_a_list': _('Expected a list of items but got type "{input_type}".'),
- 'empty': _('This list may not be empty.')
+ "not_a_list": _('Expected a list of items but got type "{input_type}".'),
+ "empty": _("This list may not be empty."),
}
def __init__(self, *args, **kwargs):
- self.child = kwargs.pop('child', copy.deepcopy(self.child))
- self.allow_empty = kwargs.pop('allow_empty', True)
- assert self.child is not None, '`child` is a required argument.'
- assert not inspect.isclass(self.child), '`child` has not been instantiated.'
+ self.child = kwargs.pop("child", copy.deepcopy(self.child))
+ self.allow_empty = kwargs.pop("allow_empty", True)
+ assert self.child is not None, "`child` is a required argument."
+ assert not inspect.isclass(self.child), "`child` has not been instantiated."
super(ListSerializer, self).__init__(*args, **kwargs)
- self.child.bind(field_name='', parent=self)
+ self.child.bind(field_name="", parent=self)
def bind(self, field_name, parent):
super(ListSerializer, self).bind(field_name, parent)
self.partial = self.parent.partial
def get_initial(self):
- if hasattr(self, 'initial_data'):
+ if hasattr(self, "initial_data"):
return self.to_representation(self.initial_data)
return []
@@ -610,7 +677,9 @@ class ListSerializer(BaseSerializer):
# We override the default field access in order to support
# lists in HTML forms.
if html.is_html_input(dictionary):
- return html.parse_html_list(dictionary, prefix=self.field_name, default=empty)
+ return html.parse_html_list(
+ dictionary, prefix=self.field_name, default=empty
+ )
return dictionary.get(self.field_name, empty)
def run_validation(self, data=empty):
@@ -627,7 +696,7 @@ class ListSerializer(BaseSerializer):
try:
self.run_validators(value)
value = self.validate(value)
- assert value is not None, '.validate() should return the validated data'
+ assert value is not None, ".validate() should return the validated data"
except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc))
@@ -641,21 +710,21 @@ class ListSerializer(BaseSerializer):
data = html.parse_html_list(data, default=[])
if not isinstance(data, list):
- message = self.error_messages['not_a_list'].format(
+ message = self.error_messages["not_a_list"].format(
input_type=type(data).__name__
)
- raise ValidationError({
- api_settings.NON_FIELD_ERRORS_KEY: [message]
- }, code='not_a_list')
+ raise ValidationError(
+ {api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="not_a_list"
+ )
if not self.allow_empty and len(data) == 0:
if self.parent and self.partial:
raise SkipField()
- message = self.error_messages['empty']
- raise ValidationError({
- api_settings.NON_FIELD_ERRORS_KEY: [message]
- }, code='empty')
+ message = self.error_messages["empty"]
+ raise ValidationError(
+ {api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="empty"
+ )
ret = []
errors = []
@@ -682,9 +751,7 @@ class ListSerializer(BaseSerializer):
# so, first get a queryset from the Manager if needed
iterable = data.all() if isinstance(data, models.Manager) else data
- return [
- self.child.to_representation(item) for item in iterable
- ]
+ return [self.child.to_representation(item) for item in iterable]
def validate(self, attrs):
return attrs
@@ -699,16 +766,14 @@ class ListSerializer(BaseSerializer):
)
def create(self, validated_data):
- return [
- self.child.create(attrs) for attrs in validated_data
- ]
+ return [self.child.create(attrs) for attrs in validated_data]
def save(self, **kwargs):
"""
Save and return a list of object instances.
"""
# Guard against incorrect use of `serializer.save(commit=False)`
- assert 'commit' not in kwargs, (
+ assert "commit" not in kwargs, (
"'commit' is not a valid keyword argument to the 'save()' method. "
"If you need to access data before committing to the database then "
"inspect 'serializer.validated_data' instead. "
@@ -724,26 +789,26 @@ class ListSerializer(BaseSerializer):
if self.instance is not None:
self.instance = self.update(self.instance, validated_data)
- assert self.instance is not None, (
- '`update()` did not return an object instance.'
- )
+ assert (
+ self.instance is not None
+ ), "`update()` did not return an object instance."
else:
self.instance = self.create(validated_data)
- assert self.instance is not None, (
- '`create()` did not return an object instance.'
- )
+ assert (
+ self.instance is not None
+ ), "`create()` did not return an object instance."
return self.instance
def is_valid(self, raise_exception=False):
# This implementation is the same as the default,
# except that we use lists, rather than dicts, as the empty case.
- assert hasattr(self, 'initial_data'), (
- 'Cannot call `.is_valid()` as no `data=` keyword argument was '
- 'passed when instantiating the serializer instance.'
+ assert hasattr(self, "initial_data"), (
+ "Cannot call `.is_valid()` as no `data=` keyword argument was "
+ "passed when instantiating the serializer instance."
)
- if not hasattr(self, '_validated_data'):
+ if not hasattr(self, "_validated_data"):
try:
self._validated_data = self.run_validation(self.initial_data)
except ValidationError as exc:
@@ -771,10 +836,14 @@ class ListSerializer(BaseSerializer):
@property
def errors(self):
ret = super(ListSerializer, self).errors
- if isinstance(ret, list) and len(ret) == 1 and getattr(ret[0], 'code', None) == 'null':
+ if (
+ isinstance(ret, list)
+ and len(ret) == 1
+ and getattr(ret[0], "code", None) == "null"
+ ):
# Edge case. Provide a more descriptive error than
# "this field may not be null", when no data is passed.
- detail = ErrorDetail('No data provided', code='null')
+ detail = ErrorDetail("No data provided", code="null")
ret = {api_settings.NON_FIELD_ERRORS_KEY: [detail]}
if isinstance(ret, dict):
return ReturnDict(ret, serializer=self)
@@ -784,6 +853,7 @@ class ListSerializer(BaseSerializer):
# ModelSerializer & HyperlinkedModelSerializer
# --------------------------------------------
+
def raise_errors_on_nested_writes(method_name, serializer, validated_data):
"""
Give explicit errors when users attempt to pass writable nested data.
@@ -809,18 +879,18 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
# ...
# profile = ProfileSerializer()
assert not any(
- isinstance(field, BaseSerializer) and
- (field.source in validated_data) and
- isinstance(validated_data[field.source], (list, dict))
+ isinstance(field, BaseSerializer)
+ and (field.source in validated_data)
+ and isinstance(validated_data[field.source], (list, dict))
for field in serializer._writable_fields
), (
- 'The `.{method_name}()` method does not support writable nested '
- 'fields by default.\nWrite an explicit `.{method_name}()` method for '
- 'serializer `{module}.{class_name}`, or set `read_only=True` on '
- 'nested serializer fields.'.format(
+ "The `.{method_name}()` method does not support writable nested "
+ "fields by default.\nWrite an explicit `.{method_name}()` method for "
+ "serializer `{module}.{class_name}`, or set `read_only=True` on "
+ "nested serializer fields.".format(
method_name=method_name,
module=serializer.__class__.__module__,
- class_name=serializer.__class__.__name__
+ class_name=serializer.__class__.__name__,
)
)
@@ -830,18 +900,18 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
# ...
# address = serializer.CharField('profile.address')
assert not any(
- '.' in field.source and
- (key in validated_data) and
- isinstance(validated_data[key], (list, dict))
+ "." in field.source
+ and (key in validated_data)
+ and isinstance(validated_data[key], (list, dict))
for key, field in serializer.fields.items()
), (
- 'The `.{method_name}()` method does not support writable dotted-source '
- 'fields by default.\nWrite an explicit `.{method_name}()` method for '
- 'serializer `{module}.{class_name}`, or set `read_only=True` on '
- 'dotted-source serializer fields.'.format(
+ "The `.{method_name}()` method does not support writable dotted-source "
+ "fields by default.\nWrite an explicit `.{method_name}()` method for "
+ "serializer `{module}.{class_name}`, or set `read_only=True` on "
+ "dotted-source serializer fields.".format(
method_name=method_name,
module=serializer.__class__.__module__,
- class_name=serializer.__class__.__name__
+ class_name=serializer.__class__.__name__,
)
)
@@ -862,6 +932,7 @@ class ModelSerializer(Serializer):
you need you should either declare the extra/differing fields explicitly on
the serializer class, or simply use a `Serializer` class.
"""
+
serializer_field_mapping = {
models.AutoField: IntegerField,
models.BigIntegerField: IntegerField,
@@ -926,7 +997,7 @@ class ModelSerializer(Serializer):
If you want to support writable nested relationships you'll need
to write an explicit `.create()` method.
"""
- raise_errors_on_nested_writes('create', self, validated_data)
+ raise_errors_on_nested_writes("create", self, validated_data)
ModelClass = self.Meta.model
@@ -944,19 +1015,19 @@ class ModelSerializer(Serializer):
except TypeError:
tb = traceback.format_exc()
msg = (
- 'Got a `TypeError` when calling `%s.%s.create()`. '
- 'This may be because you have a writable field on the '
- 'serializer class that is not a valid argument to '
- '`%s.%s.create()`. You may need to make the field '
- 'read-only, or override the %s.create() method to handle '
- 'this correctly.\nOriginal exception was:\n %s' %
- (
+ "Got a `TypeError` when calling `%s.%s.create()`. "
+ "This may be because you have a writable field on the "
+ "serializer class that is not a valid argument to "
+ "`%s.%s.create()`. You may need to make the field "
+ "read-only, or override the %s.create() method to handle "
+ "this correctly.\nOriginal exception was:\n %s"
+ % (
ModelClass.__name__,
ModelClass._default_manager.name,
ModelClass.__name__,
ModelClass._default_manager.name,
self.__class__.__name__,
- tb
+ tb,
)
)
raise TypeError(msg)
@@ -970,7 +1041,7 @@ class ModelSerializer(Serializer):
return instance
def update(self, instance, validated_data):
- raise_errors_on_nested_writes('update', self, validated_data)
+ raise_errors_on_nested_writes("update", self, validated_data)
info = model_meta.get_field_info(instance)
# Simply set each attribute on the instance, and then save it.
@@ -997,24 +1068,22 @@ class ModelSerializer(Serializer):
if self.url_field_name is None:
self.url_field_name = api_settings.URL_FIELD_NAME
- assert hasattr(self, 'Meta'), (
- 'Class {serializer_class} missing "Meta" attribute'.format(
- serializer_class=self.__class__.__name__
- )
+ assert hasattr(
+ self, "Meta"
+ ), 'Class {serializer_class} missing "Meta" attribute'.format(
+ serializer_class=self.__class__.__name__
)
- assert hasattr(self.Meta, 'model'), (
- 'Class {serializer_class} missing "Meta.model" attribute'.format(
- serializer_class=self.__class__.__name__
- )
+ assert hasattr(
+ self.Meta, "model"
+ ), 'Class {serializer_class} missing "Meta.model" attribute'.format(
+ serializer_class=self.__class__.__name__
)
if model_meta.is_abstract_model(self.Meta.model):
- raise ValueError(
- 'Cannot use ModelSerializer with Abstract Models.'
- )
+ raise ValueError("Cannot use ModelSerializer with Abstract Models.")
declared_fields = copy.deepcopy(self._declared_fields)
- model = getattr(self.Meta, 'model')
- depth = getattr(self.Meta, 'depth', 0)
+ model = getattr(self.Meta, "model")
+ depth = getattr(self.Meta, "depth", 0)
if depth is not None:
assert depth >= 0, "'depth' may not be negative."
@@ -1041,19 +1110,15 @@ class ModelSerializer(Serializer):
continue
extra_field_kwargs = extra_kwargs.get(field_name, {})
- source = extra_field_kwargs.get('source', '*')
- if source == '*':
+ source = extra_field_kwargs.get("source", "*")
+ if source == "*":
source = field_name
# Determine the serializer field class and keyword arguments.
- field_class, field_kwargs = self.build_field(
- source, info, model, depth
- )
+ field_class, field_kwargs = self.build_field(source, info, model, depth)
# Include any kwargs defined in `Meta.extra_kwargs`
- field_kwargs = self.include_extra_kwargs(
- field_kwargs, extra_field_kwargs
- )
+ field_kwargs = self.include_extra_kwargs(field_kwargs, extra_field_kwargs)
# Create the serializer field.
fields[field_name] = field_class(**field_kwargs)
@@ -1072,19 +1137,19 @@ class ModelSerializer(Serializer):
set of fields, but also takes into account the `Meta.fields` or
`Meta.exclude` options if they have been specified.
"""
- fields = getattr(self.Meta, 'fields', None)
- exclude = getattr(self.Meta, 'exclude', None)
+ fields = getattr(self.Meta, "fields", None)
+ exclude = getattr(self.Meta, "exclude", None)
if fields and fields != ALL_FIELDS and not isinstance(fields, (list, tuple)):
raise TypeError(
'The `fields` option must be a list or tuple or "__all__". '
- 'Got %s.' % type(fields).__name__
+ "Got %s." % type(fields).__name__
)
if exclude and not isinstance(exclude, (list, tuple)):
raise TypeError(
- 'The `exclude` option must be a list or tuple. Got %s.' %
- type(exclude).__name__
+ "The `exclude` option must be a list or tuple. Got %s."
+ % type(exclude).__name__
)
assert not (fields and exclude), (
@@ -1115,15 +1180,14 @@ class ModelSerializer(Serializer):
# a subset of fields.
required_field_names = set(declared_fields)
for cls in self.__class__.__bases__:
- required_field_names -= set(getattr(cls, '_declared_fields', []))
+ required_field_names -= set(getattr(cls, "_declared_fields", []))
for field_name in required_field_names:
assert field_name in fields, (
"The field '{field_name}' was declared on serializer "
"{serializer_class}, but has not been included in the "
"'fields' option.".format(
- field_name=field_name,
- serializer_class=self.__class__.__name__
+ field_name=field_name, serializer_class=self.__class__.__name__
)
)
return fields
@@ -1138,10 +1202,8 @@ class ModelSerializer(Serializer):
"Cannot both declare the field '{field_name}' and include "
"it in the {serializer_class} 'exclude' option. Remove the "
"field or, if inherited from a parent serializer, disable "
- "with `{field_name} = None`."
- .format(
- field_name=field_name,
- serializer_class=self.__class__.__name__
+ "with `{field_name} = None`.".format(
+ field_name=field_name, serializer_class=self.__class__.__name__
)
)
@@ -1149,8 +1211,7 @@ class ModelSerializer(Serializer):
"The field '{field_name}' was included on serializer "
"{serializer_class} in the 'exclude' option, but does "
"not match any model field.".format(
- field_name=field_name,
- serializer_class=self.__class__.__name__
+ field_name=field_name, serializer_class=self.__class__.__name__
)
)
fields.remove(field_name)
@@ -1163,10 +1224,10 @@ class ModelSerializer(Serializer):
`Meta.fields` option is not specified.
"""
return (
- [model_info.pk.name] +
- list(declared_fields) +
- list(model_info.fields) +
- list(model_info.forward_relations)
+ [model_info.pk.name]
+ + list(declared_fields)
+ + list(model_info.fields)
+ + list(model_info.forward_relations)
)
# Methods for constructing serializer fields...
@@ -1206,9 +1267,9 @@ class ModelSerializer(Serializer):
# Special case to handle when a OneToOneField is also the primary key
if model_field.one_to_one and model_field.primary_key:
field_class = self.serializer_related_field
- field_kwargs['queryset'] = model_field.related_model.objects
+ field_kwargs["queryset"] = model_field.related_model.objects
- if 'choices' in field_kwargs:
+ if "choices" in field_kwargs:
# Fields with choices get coerced into `ChoiceField`
# instead of using their regular typed field.
field_class = self.serializer_choice_field
@@ -1216,11 +1277,20 @@ class ModelSerializer(Serializer):
# for the choice field. We need to strip these out.
# Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES)
valid_kwargs = {
- 'read_only', 'write_only',
- 'required', 'default', 'initial', 'source',
- 'label', 'help_text', 'style',
- 'error_messages', 'validators', 'allow_null', 'allow_blank',
- 'choices'
+ "read_only",
+ "write_only",
+ "required",
+ "default",
+ "initial",
+ "source",
+ "label",
+ "help_text",
+ "style",
+ "error_messages",
+ "validators",
+ "allow_null",
+ "allow_blank",
+ "choices",
}
for key in list(field_kwargs):
if key not in valid_kwargs:
@@ -1230,20 +1300,22 @@ class ModelSerializer(Serializer):
# `model_field` is only valid for the fallback case of
# `ModelField`, which is used when no other typed field
# matched to the model field.
- field_kwargs.pop('model_field', None)
+ field_kwargs.pop("model_field", None)
- if not issubclass(field_class, CharField) and not issubclass(field_class, ChoiceField):
+ if not issubclass(field_class, CharField) and not issubclass(
+ field_class, ChoiceField
+ ):
# `allow_blank` is only valid for textual fields.
- field_kwargs.pop('allow_blank', None)
+ field_kwargs.pop("allow_blank", None)
if postgres_fields and isinstance(model_field, postgres_fields.ArrayField):
# Populate the `child` argument on `ListField` instances generated
# for the PostgreSQL specific `ArrayField`.
child_model_field = model_field.base_field
child_field_class, child_field_kwargs = self.build_standard_field(
- 'child', child_model_field
+ "child", child_model_field
)
- field_kwargs['child'] = child_field_class(**child_field_kwargs)
+ field_kwargs["child"] = child_field_class(**child_field_kwargs)
return field_class, field_kwargs
@@ -1254,14 +1326,18 @@ class ModelSerializer(Serializer):
field_class = self.serializer_related_field
field_kwargs = get_relation_kwargs(field_name, relation_info)
- to_field = field_kwargs.pop('to_field', None)
- if to_field and not relation_info.reverse and not relation_info.related_model._meta.get_field(to_field).primary_key:
- field_kwargs['slug_field'] = to_field
+ to_field = field_kwargs.pop("to_field", None)
+ if (
+ to_field
+ and not relation_info.reverse
+ and not relation_info.related_model._meta.get_field(to_field).primary_key
+ ):
+ field_kwargs["slug_field"] = to_field
field_class = self.serializer_related_to_field
# `view_name` is only valid for hyperlinked relationships.
if not issubclass(field_class, HyperlinkedRelatedField):
- field_kwargs.pop('view_name', None)
+ field_kwargs.pop("view_name", None)
return field_class, field_kwargs
@@ -1269,11 +1345,12 @@ class ModelSerializer(Serializer):
"""
Create nested fields for forward and reverse relationships.
"""
+
class NestedSerializer(ModelSerializer):
class Meta:
model = relation_info.related_model
depth = nested_depth - 1
- fields = '__all__'
+ fields = "__all__"
field_class = NestedSerializer
field_kwargs = get_nested_relation_kwargs(relation_info)
@@ -1303,8 +1380,8 @@ class ModelSerializer(Serializer):
Raise an error on any unknown fields.
"""
raise ImproperlyConfigured(
- 'Field name `%s` is not valid for model `%s`.' %
- (field_name, model_class.__name__)
+ "Field name `%s` is not valid for model `%s`."
+ % (field_name, model_class.__name__)
)
def include_extra_kwargs(self, kwargs, extra_kwargs):
@@ -1312,19 +1389,28 @@ class ModelSerializer(Serializer):
Include any 'extra_kwargs' that have been included for this field,
possibly removing any incompatible existing keyword arguments.
"""
- if extra_kwargs.get('read_only', False):
+ if extra_kwargs.get("read_only", False):
for attr in [
- 'required', 'default', 'allow_blank', 'allow_null',
- 'min_length', 'max_length', 'min_value', 'max_value',
- 'validators', 'queryset'
+ "required",
+ "default",
+ "allow_blank",
+ "allow_null",
+ "min_length",
+ "max_length",
+ "min_value",
+ "max_value",
+ "validators",
+ "queryset",
]:
kwargs.pop(attr, None)
- if extra_kwargs.get('default') and kwargs.get('required') is False:
- kwargs.pop('required')
+ if extra_kwargs.get("default") and kwargs.get("required") is False:
+ kwargs.pop("required")
- if extra_kwargs.get('read_only', kwargs.get('read_only', False)):
- extra_kwargs.pop('required', None) # Read only fields should always omit the 'required' argument.
+ if extra_kwargs.get("read_only", kwargs.get("read_only", False)):
+ extra_kwargs.pop(
+ "required", None
+ ) # Read only fields should always omit the 'required' argument.
kwargs.update(extra_kwargs)
@@ -1337,27 +1423,27 @@ class ModelSerializer(Serializer):
Return a dictionary mapping field names to a dictionary of
additional keyword arguments.
"""
- extra_kwargs = copy.deepcopy(getattr(self.Meta, 'extra_kwargs', {}))
+ extra_kwargs = copy.deepcopy(getattr(self.Meta, "extra_kwargs", {}))
- read_only_fields = getattr(self.Meta, 'read_only_fields', None)
+ read_only_fields = getattr(self.Meta, "read_only_fields", None)
if read_only_fields is not None:
if not isinstance(read_only_fields, (list, tuple)):
raise TypeError(
- 'The `read_only_fields` option must be a list or tuple. '
- 'Got %s.' % type(read_only_fields).__name__
+ "The `read_only_fields` option must be a list or tuple. "
+ "Got %s." % type(read_only_fields).__name__
)
for field_name in read_only_fields:
kwargs = extra_kwargs.get(field_name, {})
- kwargs['read_only'] = True
+ kwargs["read_only"] = True
extra_kwargs[field_name] = kwargs
else:
# Guard against the possible misspelling `readonly_fields` (used
# by the Django admin and others).
- assert not hasattr(self.Meta, 'readonly_fields'), (
- 'Serializer `%s.%s` has field `readonly_fields`; '
- 'the correct spelling for the option is `read_only_fields`.' %
- (self.__class__.__module__, self.__class__.__name__)
+ assert not hasattr(self.Meta, "readonly_fields"), (
+ "Serializer `%s.%s` has field `readonly_fields`; "
+ "the correct spelling for the option is `read_only_fields`."
+ % (self.__class__.__module__, self.__class__.__name__)
)
return extra_kwargs
@@ -1370,10 +1456,10 @@ class ModelSerializer(Serializer):
('dict of updated extra kwargs', 'mapping of hidden fields')
"""
- if getattr(self.Meta, 'validators', None) is not None:
+ if getattr(self.Meta, "validators", None) is not None:
return (extra_kwargs, {})
- model = getattr(self.Meta, 'model')
+ model = getattr(self.Meta, "model")
model_fields = self._get_model_fields(
field_names, declared_fields, extra_kwargs
)
@@ -1385,8 +1471,11 @@ class ModelSerializer(Serializer):
for model_field in model_fields.values():
# Include each of the `unique_for_*` field names.
- unique_constraint_names |= {model_field.unique_for_date, model_field.unique_for_month,
- model_field.unique_for_year}
+ unique_constraint_names |= {
+ model_field.unique_for_date,
+ model_field.unique_for_month,
+ model_field.unique_for_year,
+ }
unique_constraint_names -= {None}
@@ -1407,9 +1496,9 @@ class ModelSerializer(Serializer):
# Get the model field that is referred too.
unique_constraint_field = model._meta.get_field(unique_constraint_name)
- if getattr(unique_constraint_field, 'auto_now_add', None):
+ if getattr(unique_constraint_field, "auto_now_add", None):
default = CreateOnlyDefault(timezone.now)
- elif getattr(unique_constraint_field, 'auto_now', None):
+ elif getattr(unique_constraint_field, "auto_now", None):
default = timezone.now
elif unique_constraint_field.has_default():
default = unique_constraint_field.default
@@ -1419,9 +1508,11 @@ class ModelSerializer(Serializer):
if unique_constraint_name in model_fields:
# The corresponding field is present in the serializer
if default is empty:
- uniqueness_extra_kwargs[unique_constraint_name] = {'required': True}
+ uniqueness_extra_kwargs[unique_constraint_name] = {"required": True}
else:
- uniqueness_extra_kwargs[unique_constraint_name] = {'default': default}
+ uniqueness_extra_kwargs[unique_constraint_name] = {
+ "default": default
+ }
elif default is not empty:
# The corresponding field is not present in the
# serializer. We have a default to use for it, so
@@ -1443,7 +1534,7 @@ class ModelSerializer(Serializer):
Returned as a dict of 'model field name' -> 'model field'.
Used internally by `get_uniqueness_field_options`.
"""
- model = getattr(self.Meta, 'model')
+ model = getattr(self.Meta, "model")
model_fields = {}
for field_name in field_names:
@@ -1453,11 +1544,11 @@ class ModelSerializer(Serializer):
source = field.source or field_name
else:
try:
- source = extra_kwargs[field_name]['source']
+ source = extra_kwargs[field_name]["source"]
except KeyError:
source = field_name
- if '.' in source or source == '*':
+ if "." in source or source == "*":
# Model fields will always have a simple source mapping,
# they can't be nested attribute lookups.
continue
@@ -1478,23 +1569,22 @@ class ModelSerializer(Serializer):
Determine the set of validators to use when instantiating serializer.
"""
# If the validators have been declared explicitly then use that.
- validators = getattr(getattr(self, 'Meta', None), 'validators', None)
+ validators = getattr(getattr(self, "Meta", None), "validators", None)
if validators is not None:
return list(validators)
# Otherwise use the default set of validators.
return (
- self.get_unique_together_validators() +
- self.get_unique_for_date_validators()
+ self.get_unique_together_validators()
+ + self.get_unique_for_date_validators()
)
def get_unique_together_validators(self):
"""
Determine a default set of validators for any unique_together constraints.
"""
- model_class_inheritance_tree = (
- [self.Meta.model] +
- list(self.Meta.model._meta.parents)
+ model_class_inheritance_tree = [self.Meta.model] + list(
+ self.Meta.model._meta.parents
)
# The field names we're passing though here only include fields
@@ -1502,14 +1592,19 @@ class ModelSerializer(Serializer):
# cannot map to a field, and must be a traversal, so we're not
# including those.
field_names = {
- field.source for field in self._writable_fields
- if (field.source != '*') and ('.' not in field.source)
+ field.source
+ for field in self._writable_fields
+ if (field.source != "*") and ("." not in field.source)
}
# Special Case: Add read_only fields with defaults.
field_names |= {
- field.source for field in self.fields.values()
- if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
+ field.source
+ for field in self.fields.values()
+ if (field.read_only)
+ and (field.default != empty)
+ and (field.source != "*")
+ and ("." not in field.source)
}
# Note that we make sure to check `unique_together` both on the
@@ -1519,8 +1614,7 @@ class ModelSerializer(Serializer):
for unique_together in parent_class._meta.unique_together:
if field_names.issuperset(set(unique_together)):
validator = UniqueTogetherValidator(
- queryset=parent_class._default_manager,
- fields=unique_together
+ queryset=parent_class._default_manager, fields=unique_together
)
validators.append(validator)
return validators
@@ -1544,7 +1638,7 @@ class ModelSerializer(Serializer):
validator = UniqueForDateValidator(
queryset=default_manager,
field=field_name,
- date_field=field.unique_for_date
+ date_field=field.unique_for_date,
)
validators.append(validator)
@@ -1552,7 +1646,7 @@ class ModelSerializer(Serializer):
validator = UniqueForMonthValidator(
queryset=default_manager,
field=field_name,
- date_field=field.unique_for_month
+ date_field=field.unique_for_month,
)
validators.append(validator)
@@ -1560,18 +1654,18 @@ class ModelSerializer(Serializer):
validator = UniqueForYearValidator(
queryset=default_manager,
field=field_name,
- date_field=field.unique_for_year
+ date_field=field.unique_for_year,
)
validators.append(validator)
return validators
-if hasattr(models, 'UUIDField'):
+if hasattr(models, "UUIDField"):
ModelSerializer.serializer_field_mapping[models.UUIDField] = UUIDField
# IPAddressField is deprecated in Django
-if hasattr(models, 'IPAddressField'):
+if hasattr(models, "IPAddressField"):
ModelSerializer.serializer_field_mapping[models.IPAddressField] = IPAddressField
if postgres_fields:
@@ -1588,6 +1682,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
* A 'url' field is included instead of the 'id' field.
* Relationships to other instances are hyperlinks, instead of primary keys.
"""
+
serializer_related_field = HyperlinkedRelatedField
def get_default_field_names(self, declared_fields, model_info):
@@ -1596,21 +1691,22 @@ class HyperlinkedModelSerializer(ModelSerializer):
`Meta.fields` option is not specified.
"""
return (
- [self.url_field_name] +
- list(declared_fields) +
- list(model_info.fields) +
- list(model_info.forward_relations)
+ [self.url_field_name]
+ + list(declared_fields)
+ + list(model_info.fields)
+ + list(model_info.forward_relations)
)
def build_nested_field(self, field_name, relation_info, nested_depth):
"""
Create nested fields for forward and reverse relationships.
"""
+
class NestedSerializer(HyperlinkedModelSerializer):
class Meta:
model = relation_info.related_model
depth = nested_depth - 1
- fields = '__all__'
+ fields = "__all__"
field_class = NestedSerializer
field_kwargs = get_nested_relation_kwargs(relation_info)
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 8db9c81ed..9fd09a9b1 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -28,135 +28,109 @@ from django.utils import six
from rest_framework import ISO_8601
+
DEFAULTS = {
# Base API policies
- 'DEFAULT_RENDERER_CLASSES': (
- 'rest_framework.renderers.JSONRenderer',
- 'rest_framework.renderers.BrowsableAPIRenderer',
+ "DEFAULT_RENDERER_CLASSES": (
+ "rest_framework.renderers.JSONRenderer",
+ "rest_framework.renderers.BrowsableAPIRenderer",
),
- 'DEFAULT_PARSER_CLASSES': (
- 'rest_framework.parsers.JSONParser',
- 'rest_framework.parsers.FormParser',
- 'rest_framework.parsers.MultiPartParser'
+ "DEFAULT_PARSER_CLASSES": (
+ "rest_framework.parsers.JSONParser",
+ "rest_framework.parsers.FormParser",
+ "rest_framework.parsers.MultiPartParser",
),
- 'DEFAULT_AUTHENTICATION_CLASSES': (
- 'rest_framework.authentication.SessionAuthentication',
- 'rest_framework.authentication.BasicAuthentication'
+ "DEFAULT_AUTHENTICATION_CLASSES": (
+ "rest_framework.authentication.SessionAuthentication",
+ "rest_framework.authentication.BasicAuthentication",
),
- 'DEFAULT_PERMISSION_CLASSES': (
- 'rest_framework.permissions.AllowAny',
- ),
- 'DEFAULT_THROTTLE_CLASSES': (),
- 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
- 'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
- 'DEFAULT_VERSIONING_CLASS': None,
-
+ "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.AllowAny",),
+ "DEFAULT_THROTTLE_CLASSES": (),
+ "DEFAULT_CONTENT_NEGOTIATION_CLASS": "rest_framework.negotiation.DefaultContentNegotiation",
+ "DEFAULT_METADATA_CLASS": "rest_framework.metadata.SimpleMetadata",
+ "DEFAULT_VERSIONING_CLASS": None,
# Generic view behavior
- 'DEFAULT_PAGINATION_CLASS': None,
- 'DEFAULT_FILTER_BACKENDS': (),
-
+ "DEFAULT_PAGINATION_CLASS": None,
+ "DEFAULT_FILTER_BACKENDS": (),
# Schema
- 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
-
+ "DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.AutoSchema",
# Throttling
- 'DEFAULT_THROTTLE_RATES': {
- 'user': None,
- 'anon': None,
- },
- 'NUM_PROXIES': None,
-
+ "DEFAULT_THROTTLE_RATES": {"user": None, "anon": None},
+ "NUM_PROXIES": None,
# Pagination
- 'PAGE_SIZE': None,
-
+ "PAGE_SIZE": None,
# Filtering
- 'SEARCH_PARAM': 'search',
- 'ORDERING_PARAM': 'ordering',
-
+ "SEARCH_PARAM": "search",
+ "ORDERING_PARAM": "ordering",
# Versioning
- 'DEFAULT_VERSION': None,
- 'ALLOWED_VERSIONS': None,
- 'VERSION_PARAM': 'version',
-
+ "DEFAULT_VERSION": None,
+ "ALLOWED_VERSIONS": None,
+ "VERSION_PARAM": "version",
# Authentication
- 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
- 'UNAUTHENTICATED_TOKEN': None,
-
+ "UNAUTHENTICATED_USER": "django.contrib.auth.models.AnonymousUser",
+ "UNAUTHENTICATED_TOKEN": None,
# View configuration
- 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
- 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
-
+ "VIEW_NAME_FUNCTION": "rest_framework.views.get_view_name",
+ "VIEW_DESCRIPTION_FUNCTION": "rest_framework.views.get_view_description",
# Exception handling
- 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
- 'NON_FIELD_ERRORS_KEY': 'non_field_errors',
-
+ "EXCEPTION_HANDLER": "rest_framework.views.exception_handler",
+ "NON_FIELD_ERRORS_KEY": "non_field_errors",
# Testing
- 'TEST_REQUEST_RENDERER_CLASSES': (
- 'rest_framework.renderers.MultiPartRenderer',
- 'rest_framework.renderers.JSONRenderer'
+ "TEST_REQUEST_RENDERER_CLASSES": (
+ "rest_framework.renderers.MultiPartRenderer",
+ "rest_framework.renderers.JSONRenderer",
),
- 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
-
+ "TEST_REQUEST_DEFAULT_FORMAT": "multipart",
# Hyperlink settings
- 'URL_FORMAT_OVERRIDE': 'format',
- 'FORMAT_SUFFIX_KWARG': 'format',
- 'URL_FIELD_NAME': 'url',
-
+ "URL_FORMAT_OVERRIDE": "format",
+ "FORMAT_SUFFIX_KWARG": "format",
+ "URL_FIELD_NAME": "url",
# Input and output formats
- 'DATE_FORMAT': ISO_8601,
- 'DATE_INPUT_FORMATS': (ISO_8601,),
-
- 'DATETIME_FORMAT': ISO_8601,
- 'DATETIME_INPUT_FORMATS': (ISO_8601,),
-
- 'TIME_FORMAT': ISO_8601,
- 'TIME_INPUT_FORMATS': (ISO_8601,),
-
+ "DATE_FORMAT": ISO_8601,
+ "DATE_INPUT_FORMATS": (ISO_8601,),
+ "DATETIME_FORMAT": ISO_8601,
+ "DATETIME_INPUT_FORMATS": (ISO_8601,),
+ "TIME_FORMAT": ISO_8601,
+ "TIME_INPUT_FORMATS": (ISO_8601,),
# Encoding
- 'UNICODE_JSON': True,
- 'COMPACT_JSON': True,
- 'STRICT_JSON': True,
- 'COERCE_DECIMAL_TO_STRING': True,
- 'UPLOADED_FILES_USE_URL': True,
-
+ "UNICODE_JSON": True,
+ "COMPACT_JSON": True,
+ "STRICT_JSON": True,
+ "COERCE_DECIMAL_TO_STRING": True,
+ "UPLOADED_FILES_USE_URL": True,
# Browseable API
- 'HTML_SELECT_CUTOFF': 1000,
- 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
-
+ "HTML_SELECT_CUTOFF": 1000,
+ "HTML_SELECT_CUTOFF_TEXT": "More than {count} items...",
# Schemas
- 'SCHEMA_COERCE_PATH_PK': True,
- 'SCHEMA_COERCE_METHOD_NAMES': {
- 'retrieve': 'read',
- 'destroy': 'delete'
- },
+ "SCHEMA_COERCE_PATH_PK": True,
+ "SCHEMA_COERCE_METHOD_NAMES": {"retrieve": "read", "destroy": "delete"},
}
# List of settings that may be in string import notation.
IMPORT_STRINGS = (
- 'DEFAULT_RENDERER_CLASSES',
- 'DEFAULT_PARSER_CLASSES',
- 'DEFAULT_AUTHENTICATION_CLASSES',
- 'DEFAULT_PERMISSION_CLASSES',
- 'DEFAULT_THROTTLE_CLASSES',
- 'DEFAULT_CONTENT_NEGOTIATION_CLASS',
- 'DEFAULT_METADATA_CLASS',
- 'DEFAULT_VERSIONING_CLASS',
- 'DEFAULT_PAGINATION_CLASS',
- 'DEFAULT_FILTER_BACKENDS',
- 'DEFAULT_SCHEMA_CLASS',
- 'EXCEPTION_HANDLER',
- 'TEST_REQUEST_RENDERER_CLASSES',
- 'UNAUTHENTICATED_USER',
- 'UNAUTHENTICATED_TOKEN',
- 'VIEW_NAME_FUNCTION',
- 'VIEW_DESCRIPTION_FUNCTION'
+ "DEFAULT_RENDERER_CLASSES",
+ "DEFAULT_PARSER_CLASSES",
+ "DEFAULT_AUTHENTICATION_CLASSES",
+ "DEFAULT_PERMISSION_CLASSES",
+ "DEFAULT_THROTTLE_CLASSES",
+ "DEFAULT_CONTENT_NEGOTIATION_CLASS",
+ "DEFAULT_METADATA_CLASS",
+ "DEFAULT_VERSIONING_CLASS",
+ "DEFAULT_PAGINATION_CLASS",
+ "DEFAULT_FILTER_BACKENDS",
+ "DEFAULT_SCHEMA_CLASS",
+ "EXCEPTION_HANDLER",
+ "TEST_REQUEST_RENDERER_CLASSES",
+ "UNAUTHENTICATED_USER",
+ "UNAUTHENTICATED_TOKEN",
+ "VIEW_NAME_FUNCTION",
+ "VIEW_DESCRIPTION_FUNCTION",
)
# List of settings that have been removed
-REMOVED_SETTINGS = (
- "PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY",
-)
+REMOVED_SETTINGS = ("PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY")
def perform_import(val, setting_name):
@@ -179,11 +153,16 @@ def import_from_string(val, setting_name):
"""
try:
# Nod to tastypie's use of importlib.
- module_path, class_name = val.rsplit('.', 1)
+ module_path, class_name = val.rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
- msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
+ msg = "Could not import '%s' for API setting '%s'. %s: %s." % (
+ val,
+ setting_name,
+ e.__class__.__name__,
+ e,
+ )
raise ImportError(msg)
@@ -198,6 +177,7 @@ class APISettings(object):
Any setting with string import paths will be automatically resolved
and return the class, rather than the string literal.
"""
+
def __init__(self, user_settings=None, defaults=None, import_strings=None):
if user_settings:
self._user_settings = self.__check_user_settings(user_settings)
@@ -207,8 +187,8 @@ class APISettings(object):
@property
def user_settings(self):
- if not hasattr(self, '_user_settings'):
- self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
+ if not hasattr(self, "_user_settings"):
+ self._user_settings = getattr(settings, "REST_FRAMEWORK", {})
return self._user_settings
def __getattr__(self, attr):
@@ -235,23 +215,26 @@ class APISettings(object):
SETTINGS_DOC = "https://www.django-rest-framework.org/api-guide/settings/"
for setting in REMOVED_SETTINGS:
if setting in user_settings:
- raise RuntimeError("The '%s' setting has been removed. Please refer to '%s' for available settings." % (setting, SETTINGS_DOC))
+ raise RuntimeError(
+ "The '%s' setting has been removed. Please refer to '%s' for available settings."
+ % (setting, SETTINGS_DOC)
+ )
return user_settings
def reload(self):
for attr in self._cached_attrs:
delattr(self, attr)
self._cached_attrs.clear()
- if hasattr(self, '_user_settings'):
- delattr(self, '_user_settings')
+ if hasattr(self, "_user_settings"):
+ delattr(self, "_user_settings")
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_api_settings(*args, **kwargs):
- setting = kwargs['setting']
- if setting == 'REST_FRAMEWORK':
+ setting = kwargs["setting"]
+ if setting == "REST_FRAMEWORK":
api_settings.reload()
diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py
index f48675d5e..b71599a5d 100644
--- a/rest_framework/templatetags/rest_framework.py
+++ b/rest_framework/templatetags/rest_framework.py
@@ -15,22 +15,23 @@ from rest_framework.compat import apply_markdown, pygments_highlight
from rest_framework.renderers import HTMLFormRenderer
from rest_framework.utils.urls import replace_query_param
+
register = template.Library()
# Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
-@register.tag(name='code')
+@register.tag(name="code")
def highlight_code(parser, token):
code = token.split_contents()[-1]
- nodelist = parser.parse(('endcode',))
+ nodelist = parser.parse(("endcode",))
parser.delete_first_token()
return CodeNode(code, nodelist)
class CodeNode(template.Node):
- style = 'emacs'
+ style = "emacs"
def __init__(self, lang, code):
self.lang = lang
@@ -43,24 +44,17 @@ class CodeNode(template.Node):
@register.filter()
def with_location(fields, location):
- return [
- field for field in fields
- if field.location == location
- ]
+ return [field for field in fields if field.location == location]
@register.simple_tag
def form_for_link(link):
import coreschema
- properties = OrderedDict([
- (field.name, field.schema or coreschema.String())
- for field in link.fields
- ])
- required = [
- field.name
- for field in link.fields
- if field.required
- ]
+
+ properties = OrderedDict(
+ [(field.name, field.schema or coreschema.String()) for field in link.fields]
+ )
+ required = [field.name for field in link.fields if field.required]
schema = coreschema.Object(properties=properties, required=required)
return mark_safe(coreschema.render_to_form(schema))
@@ -79,14 +73,14 @@ def get_pagination_html(pager):
@register.simple_tag
def render_form(serializer, template_pack=None):
- style = {'template_pack': template_pack} if template_pack else {}
+ style = {"template_pack": template_pack} if template_pack else {}
renderer = HTMLFormRenderer()
- return renderer.render(serializer.data, None, {'style': style})
+ return renderer.render(serializer.data, None, {"style": style})
@register.simple_tag
def render_field(field, style):
- renderer = style.get('renderer', HTMLFormRenderer())
+ renderer = style.get("renderer", HTMLFormRenderer())
return renderer.render_field(field, style)
@@ -96,9 +90,9 @@ def optional_login(request):
Include a login snippet if REST framework's login view is in the URLconf.
"""
try:
- login_url = reverse('rest_framework:login')
+ login_url = reverse("rest_framework:login")
except NoReverseMatch:
- return ''
+ return ""
snippet = "Log in"
snippet = format_html(snippet, href=login_url, next=escape(request.path))
@@ -112,9 +106,9 @@ def optional_docs_login(request):
Include a login snippet if REST framework's login view is in the URLconf.
"""
try:
- login_url = reverse('rest_framework:login')
+ login_url = reverse("rest_framework:login")
except NoReverseMatch:
- return 'log in'
+ return "log in"
snippet = "log in"
snippet = format_html(snippet, href=login_url, next=escape(request.path))
@@ -128,7 +122,7 @@ def optional_logout(request, user):
Include a logout snippet if REST framework's logout view is in the URLconf.
"""
try:
- logout_url = reverse('rest_framework:logout')
+ logout_url = reverse("rest_framework:logout")
except NoReverseMatch:
snippet = format_html('{user}', user=escape(user))
return mark_safe(snippet)
@@ -142,7 +136,9 @@ def optional_logout(request, user):
Log out
"""
- snippet = format_html(snippet, user=escape(user), href=logout_url, next=escape(request.path))
+ snippet = format_html(
+ snippet, user=escape(user), href=logout_url, next=escape(request.path)
+ )
return mark_safe(snippet)
@@ -160,16 +156,13 @@ def add_query_param(request, key, val):
@register.filter
def as_string(value):
if value is None:
- return ''
- return '%s' % value
+ return ""
+ return "%s" % value
@register.filter
def as_list_of_strings(value):
- return [
- '' if (item is None) else ('%s' % item)
- for item in value
- ]
+ return ["" if (item is None) else ("%s" % item) for item in value]
@register.filter
@@ -190,45 +183,52 @@ def add_class(value, css_class):
html = six.text_type(value)
match = class_re.search(html)
if match:
- m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class,
- css_class, css_class),
- match.group(1))
+ m = re.search(
+ r"^%s$|^%s\s|\s%s\s|\s%s$" % (css_class, css_class, css_class, css_class),
+ match.group(1),
+ )
if not m:
- return mark_safe(class_re.sub(match.group(1) + " " + css_class,
- html))
+ return mark_safe(class_re.sub(match.group(1) + " " + css_class, html))
else:
- return mark_safe(html.replace('>', ' class="%s">' % css_class, 1))
+ return mark_safe(html.replace(">", ' class="%s">' % css_class, 1))
return value
@register.filter
def format_value(value):
- if getattr(value, 'is_hyperlink', False):
+ if getattr(value, "is_hyperlink", False):
name = six.text_type(value.obj)
- return mark_safe('%s' % (value, escape(name)))
+ return mark_safe("%s" % (value, escape(name)))
if value is None or isinstance(value, bool):
- return mark_safe('%s
' % {True: 'true', False: 'false', None: 'null'}[value])
+ return mark_safe(
+ "%s
" % {True: "true", False: "false", None: "null"}[value]
+ )
elif isinstance(value, list):
if any([isinstance(item, (list, dict)) for item in value]):
- template = loader.get_template('rest_framework/admin/list_value.html')
+ template = loader.get_template("rest_framework/admin/list_value.html")
else:
- template = loader.get_template('rest_framework/admin/simple_list_value.html')
- context = {'value': value}
+ template = loader.get_template(
+ "rest_framework/admin/simple_list_value.html"
+ )
+ context = {"value": value}
return template.render(context)
elif isinstance(value, dict):
- template = loader.get_template('rest_framework/admin/dict_value.html')
- context = {'value': value}
+ template = loader.get_template("rest_framework/admin/dict_value.html")
+ context = {"value": value}
return template.render(context)
elif isinstance(value, six.string_types):
- if (
- (value.startswith('http:') or value.startswith('https:')) and not
- re.search(r'\s', value)
+ if (value.startswith("http:") or value.startswith("https:")) and not re.search(
+ r"\s", value
):
- return mark_safe('{value}'.format(value=escape(value)))
- elif '@' in value and not re.search(r'\s', value):
- return mark_safe('{value}'.format(value=escape(value)))
- elif '\n' in value:
- return mark_safe('%s
' % escape(value))
+ return mark_safe(
+ '{value}'.format(value=escape(value))
+ )
+ elif "@" in value and not re.search(r"\s", value):
+ return mark_safe(
+ '{value}'.format(value=escape(value))
+ )
+ elif "\n" in value:
+ return mark_safe("%s
" % escape(value))
return six.text_type(value)
@@ -266,7 +266,7 @@ def schema_links(section, sec_key=None):
"""
Recursively find every link in a schema, even nested.
"""
- NESTED_FORMAT = '%s > %s' # this format is used in docs/js/api.js:normalizeKeys
+ NESTED_FORMAT = "%s > %s" # this format is used in docs/js/api.js:normalizeKeys
links = section.links
if section.data:
data = section.data.items()
@@ -287,20 +287,30 @@ def schema_links(section, sec_key=None):
@register.filter
def add_nested_class(value):
if isinstance(value, dict):
- return 'class=nested'
- if isinstance(value, list) and any([isinstance(item, (list, dict)) for item in value]):
- return 'class=nested'
- return ''
+ return "class=nested"
+ if isinstance(value, list) and any(
+ [isinstance(item, (list, dict)) for item in value]
+ ):
+ return "class=nested"
+ return ""
# Bunch of stuff cloned from urlize
-TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]
-WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'),
- ('"', '"'), ("'", "'")]
-word_split_re = re.compile(r'(\s+)')
-simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
-simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
-simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
+TRAILING_PUNCTUATION = [".", ",", ":", ";", ".)", '"', "']", "'}", "'"]
+WRAPPING_PUNCTUATION = [
+ ("(", ")"),
+ ("<", ">"),
+ ("[", "]"),
+ ("<", ">"),
+ ('"', '"'),
+ ("'", "'"),
+]
+word_split_re = re.compile(r"(\s+)")
+simple_url_re = re.compile(r"^https?://\[?\w", re.IGNORECASE)
+simple_url_2_re = re.compile(
+ r"^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$", re.IGNORECASE
+)
+simple_email_re = re.compile(r"^\S+@\S+\.\S+$")
def smart_urlquote_wrapper(matched_url):
@@ -332,8 +342,13 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
If autoescape is True, the link text and URLs will get autoescaped.
"""
+
def trim_url(x, limit=trim_url_limit):
- return limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
+ return (
+ limit is not None
+ and (len(x) > limit and ("%s..." % x[: max(0, limit - 3)]))
+ or x
+ )
safe_input = isinstance(text, SafeData)
@@ -344,40 +359,40 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words = word_split_re.split(force_text(text))
for i, word in enumerate(words):
- if '.' in word or '@' in word or ':' in word:
+ if "." in word or "@" in word or ":" in word:
# Deal with punctuation.
- lead, middle, trail = '', word, ''
+ lead, middle, trail = "", word, ""
for punctuation in TRAILING_PUNCTUATION:
if middle.endswith(punctuation):
- middle = middle[:-len(punctuation)]
+ middle = middle[: -len(punctuation)]
trail = punctuation + trail
for opening, closing in WRAPPING_PUNCTUATION:
if middle.startswith(opening):
- middle = middle[len(opening):]
+ middle = middle[len(opening) :]
lead = lead + opening
# Keep parentheses at the end only if they're balanced.
if (
- middle.endswith(closing) and
- middle.count(closing) == middle.count(opening) + 1
+ middle.endswith(closing)
+ and middle.count(closing) == middle.count(opening) + 1
):
- middle = middle[:-len(closing)]
+ middle = middle[: -len(closing)]
trail = closing + trail
# Make URL we want to point to.
url = None
- nofollow_attr = ' rel="nofollow"' if nofollow else ''
+ nofollow_attr = ' rel="nofollow"' if nofollow else ""
if simple_url_re.match(middle):
url = smart_urlquote_wrapper(middle)
elif simple_url_2_re.match(middle):
- url = smart_urlquote_wrapper('http://%s' % middle)
- elif ':' not in middle and simple_email_re.match(middle):
- local, domain = middle.rsplit('@', 1)
+ url = smart_urlquote_wrapper("http://%s" % middle)
+ elif ":" not in middle and simple_email_re.match(middle):
+ local, domain = middle.rsplit("@", 1)
try:
- domain = domain.encode('idna').decode('ascii')
+ domain = domain.encode("idna").decode("ascii")
except UnicodeError:
continue
- url = 'mailto:%s@%s' % (local, domain)
- nofollow_attr = ''
+ url = "mailto:%s@%s" % (local, domain)
+ nofollow_attr = ""
# Make link.
if url:
@@ -385,12 +400,12 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
lead, trail = conditional_escape(lead), conditional_escape(trail)
url, trimmed = conditional_escape(url), conditional_escape(trimmed)
middle = '%s' % (url, nofollow_attr, trimmed)
- words[i] = '%s%s%s' % (lead, middle, trail)
+ words[i] = "%s%s%s" % (lead, middle, trail)
else:
words[i] = conditional_escape(word)
else:
words[i] = conditional_escape(word)
- return mark_safe(''.join(words))
+ return mark_safe("".join(words))
@register.filter
@@ -399,6 +414,6 @@ def break_long_headers(header):
Breaks headers longer than 160 characters (~page length)
when possible (are comma separated)
"""
- if len(header) > 160 and ',' in header:
- header = mark_safe('
' + ',
'.join(header.split(',')))
+ if len(header) > 160 and "," in header:
+ header = mark_safe("
" + ",
".join(header.split(",")))
return header
diff --git a/rest_framework/test.py b/rest_framework/test.py
index edacf0066..0dd6944ec 100644
--- a/rest_framework/test.py
+++ b/rest_framework/test.py
@@ -11,9 +11,11 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler
from django.test import override_settings, testcases
-from django.test.client import Client as DjangoClient
-from django.test.client import ClientHandler
-from django.test.client import RequestFactory as DjangoRequestFactory
+from django.test.client import (
+ Client as DjangoClient,
+ ClientHandler,
+ RequestFactory as DjangoRequestFactory,
+)
from django.utils import six
from django.utils.encoding import force_bytes
from django.utils.http import urlencode
@@ -28,6 +30,7 @@ def force_authenticate(request, user=None, token=None):
if requests is not None:
+
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default):
return self.getheaders(key)
@@ -48,6 +51,7 @@ if requests is not None:
A transport adapter for `requests`, that makes requests via the
Django WSGI app, rather than making actual HTTP requests over the network.
"""
+
def __init__(self):
self.app = WSGIHandler()
self.factory = DjangoRequestFactory()
@@ -62,19 +66,19 @@ if requests is not None:
# Set request content, if any exists.
if request.body is not None:
- if hasattr(request.body, 'read'):
- kwargs['data'] = request.body.read()
+ if hasattr(request.body, "read"):
+ kwargs["data"] = request.body.read()
else:
- kwargs['data'] = request.body
- if 'content-type' in request.headers:
- kwargs['content_type'] = request.headers['content-type']
+ kwargs["data"] = request.body
+ if "content-type" in request.headers:
+ kwargs["content_type"] = request.headers["content-type"]
# Set request headers.
for key, value in request.headers.items():
key = key.upper()
- if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
+ if key in ("CONNECTION", "CONTENT-LENGTH", "CONTENT-TYPE"):
continue
- kwargs['HTTP_%s' % key.replace('-', '_')] = value
+ kwargs["HTTP_%s" % key.replace("-", "_")] = value
return self.factory.generic(method, url, **kwargs).environ
@@ -85,20 +89,20 @@ if requests is not None:
raw_kwargs = {}
def start_response(wsgi_status, wsgi_headers):
- status, _, reason = wsgi_status.partition(' ')
- raw_kwargs['status'] = int(status)
- raw_kwargs['reason'] = reason
- raw_kwargs['headers'] = wsgi_headers
- raw_kwargs['version'] = 11
- raw_kwargs['preload_content'] = False
- raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
+ status, _, reason = wsgi_status.partition(" ")
+ raw_kwargs["status"] = int(status)
+ raw_kwargs["reason"] = reason
+ raw_kwargs["headers"] = wsgi_headers
+ raw_kwargs["version"] = 11
+ raw_kwargs["preload_content"] = False
+ raw_kwargs["original_response"] = MockOriginalResponse(wsgi_headers)
# Make the outgoing request via WSGI.
environ = self.get_environ(request)
wsgi_response = self.app(environ, start_response)
# Build the underlying urllib3.HTTPResponse
- raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
+ raw_kwargs["body"] = io.BytesIO(b"".join(wsgi_response))
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
# Build the requests.Response
@@ -111,33 +115,47 @@ if requests is not None:
def __init__(self, *args, **kwargs):
super(RequestsClient, self).__init__(*args, **kwargs)
adapter = DjangoTestAdapter()
- self.mount('http://', adapter)
- self.mount('https://', adapter)
+ self.mount("http://", adapter)
+ self.mount("https://", adapter)
def request(self, method, url, *args, **kwargs):
- if not url.startswith('http'):
- raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
+ if not url.startswith("http"):
+ raise ValueError(
+ 'Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"'
+ % url
+ )
return super(RequestsClient, self).request(method, url, *args, **kwargs)
+
else:
+
def RequestsClient(*args, **kwargs):
- raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
+ raise ImproperlyConfigured(
+ "requests must be installed in order to use RequestsClient."
+ )
if coreapi is not None:
+
class CoreAPIClient(coreapi.Client):
def __init__(self, *args, **kwargs):
self._session = RequestsClient()
- kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
+ kwargs["transports"] = [
+ coreapi.transports.HTTPTransport(session=self.session)
+ ]
return super(CoreAPIClient, self).__init__(*args, **kwargs)
@property
def session(self):
return self._session
+
else:
+
def CoreAPIClient(*args, **kwargs):
- raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
+ raise ImproperlyConfigured(
+ "coreapi must be installed in order to use CoreAPIClient."
+ )
class APIRequestFactory(DjangoRequestFactory):
@@ -157,11 +175,11 @@ class APIRequestFactory(DjangoRequestFactory):
"""
if data is None:
- return ('', content_type)
+ return ("", content_type)
- assert format is None or content_type is None, (
- 'You may not set both `format` and `content_type`.'
- )
+ assert (
+ format is None or content_type is None
+ ), "You may not set both `format` and `content_type`."
if content_type:
# Content type specified explicitly, treat data as a raw bytestring
@@ -175,7 +193,7 @@ class APIRequestFactory(DjangoRequestFactory):
"Set TEST_REQUEST_RENDERER_CLASSES to enable "
"extra request formats.".format(
format,
- ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
+ ", ".join(["'" + fmt + "'" for fmt in self.renderer_classes]),
)
)
@@ -195,47 +213,53 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type
def get(self, path, data=None, **extra):
- r = {
- 'QUERY_STRING': urlencode(data or {}, doseq=True),
- }
- if not data and '?' in path:
+ r = {"QUERY_STRING": urlencode(data or {}, doseq=True)}
+ if not data and "?" in path:
# Fix to support old behavior where you have the arguments in the
# url. See #1461.
- query_string = force_bytes(path.split('?')[1])
+ query_string = force_bytes(path.split("?")[1])
if six.PY3:
- query_string = query_string.decode('iso-8859-1')
- r['QUERY_STRING'] = query_string
+ query_string = query_string.decode("iso-8859-1")
+ r["QUERY_STRING"] = query_string
r.update(extra)
- return self.generic('GET', path, **r)
+ return self.generic("GET", path, **r)
def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
- return self.generic('POST', path, data, content_type, **extra)
+ return self.generic("POST", path, data, content_type, **extra)
def put(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
- return self.generic('PUT', path, data, content_type, **extra)
+ return self.generic("PUT", path, data, content_type, **extra)
def patch(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
- return self.generic('PATCH', path, data, content_type, **extra)
+ return self.generic("PATCH", path, data, content_type, **extra)
def delete(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
- return self.generic('DELETE', path, data, content_type, **extra)
+ return self.generic("DELETE", path, data, content_type, **extra)
def options(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
- return self.generic('OPTIONS', path, data, content_type, **extra)
+ return self.generic("OPTIONS", path, data, content_type, **extra)
- def generic(self, method, path, data='',
- content_type='application/octet-stream', secure=False, **extra):
+ def generic(
+ self,
+ method,
+ path,
+ data="",
+ content_type="application/octet-stream",
+ secure=False,
+ **extra
+ ):
# Include the CONTENT_TYPE, regardless of whether or not data is empty.
if content_type is not None:
- extra['CONTENT_TYPE'] = str(content_type)
+ extra["CONTENT_TYPE"] = str(content_type)
return super(APIRequestFactory, self).generic(
- method, path, data, content_type, secure, **extra)
+ method, path, data, content_type, secure, **extra
+ )
def request(self, **kwargs):
request = super(APIRequestFactory, self).request(**kwargs)
@@ -294,42 +318,52 @@ class APIClient(APIRequestFactory, DjangoClient):
response = self._handle_redirects(response, **extra)
return response
- def post(self, path, data=None, format=None, content_type=None,
- follow=False, **extra):
+ def post(
+ self, path, data=None, format=None, content_type=None, follow=False, **extra
+ ):
response = super(APIClient, self).post(
- path, data=data, format=format, content_type=content_type, **extra)
+ path, data=data, format=format, content_type=content_type, **extra
+ )
if follow:
response = self._handle_redirects(response, **extra)
return response
- def put(self, path, data=None, format=None, content_type=None,
- follow=False, **extra):
+ def put(
+ self, path, data=None, format=None, content_type=None, follow=False, **extra
+ ):
response = super(APIClient, self).put(
- path, data=data, format=format, content_type=content_type, **extra)
+ path, data=data, format=format, content_type=content_type, **extra
+ )
if follow:
response = self._handle_redirects(response, **extra)
return response
- def patch(self, path, data=None, format=None, content_type=None,
- follow=False, **extra):
+ def patch(
+ self, path, data=None, format=None, content_type=None, follow=False, **extra
+ ):
response = super(APIClient, self).patch(
- path, data=data, format=format, content_type=content_type, **extra)
+ path, data=data, format=format, content_type=content_type, **extra
+ )
if follow:
response = self._handle_redirects(response, **extra)
return response
- def delete(self, path, data=None, format=None, content_type=None,
- follow=False, **extra):
+ def delete(
+ self, path, data=None, format=None, content_type=None, follow=False, **extra
+ ):
response = super(APIClient, self).delete(
- path, data=data, format=format, content_type=content_type, **extra)
+ path, data=data, format=format, content_type=content_type, **extra
+ )
if follow:
response = self._handle_redirects(response, **extra)
return response
- def options(self, path, data=None, format=None, content_type=None,
- follow=False, **extra):
+ def options(
+ self, path, data=None, format=None, content_type=None, follow=False, **extra
+ ):
response = super(APIClient, self).options(
- path, data=data, format=format, content_type=content_type, **extra)
+ path, data=data, format=format, content_type=content_type, **extra
+ )
if follow:
response = self._handle_redirects(response, **extra)
return response
@@ -377,13 +411,14 @@ class URLPatternsTestCase(testcases.SimpleTestCase):
def test_something_else(self):
...
"""
+
@classmethod
def setUpClass(cls):
# Get the module of the TestCase subclass
cls._module = import_module(cls.__module__)
cls._override = override_settings(ROOT_URLCONF=cls.__module__)
- if hasattr(cls._module, 'urlpatterns'):
+ if hasattr(cls._module, "urlpatterns"):
cls._module_urlpatterns = cls._module.urlpatterns
cls._module.urlpatterns = cls.urlpatterns
@@ -396,7 +431,7 @@ class URLPatternsTestCase(testcases.SimpleTestCase):
super(URLPatternsTestCase, cls).tearDownClass()
cls._override.disable()
- if hasattr(cls, '_module_urlpatterns'):
+ if hasattr(cls, "_module_urlpatterns"):
cls._module.urlpatterns = cls._module_urlpatterns
else:
del cls._module.urlpatterns
diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py
index 834ced148..0f7f741fd 100644
--- a/rest_framework/throttling.py
+++ b/rest_framework/throttling.py
@@ -20,7 +20,7 @@ class BaseThrottle(object):
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
- raise NotImplementedError('.allow_request() must be overridden')
+ raise NotImplementedError(".allow_request() must be overridden")
def get_ident(self, request):
"""
@@ -28,18 +28,18 @@ class BaseThrottle(object):
if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
"""
- xff = request.META.get('HTTP_X_FORWARDED_FOR')
- remote_addr = request.META.get('REMOTE_ADDR')
+ xff = request.META.get("HTTP_X_FORWARDED_FOR")
+ remote_addr = request.META.get("REMOTE_ADDR")
num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
- addrs = xff.split(',')
+ addrs = xff.split(",")
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()
- return ''.join(xff.split()) if xff else remote_addr
+ return "".join(xff.split()) if xff else remote_addr
def wait(self):
"""
@@ -61,14 +61,15 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache.
"""
+
cache = default_cache
timer = time.time
- cache_format = 'throttle_%(scope)s_%(ident)s'
+ cache_format = "throttle_%(scope)s_%(ident)s"
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
- if not getattr(self, 'rate', None):
+ if not getattr(self, "rate", None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
@@ -79,15 +80,17 @@ class SimpleRateThrottle(BaseThrottle):
May return `None` if the request should not be throttled.
"""
- raise NotImplementedError('.get_cache_key() must be overridden')
+ raise NotImplementedError(".get_cache_key() must be overridden")
def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
- if not getattr(self, 'scope', None):
- msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
- self.__class__.__name__)
+ if not getattr(self, "scope", None):
+ msg = (
+ "You must set either `.scope` or `.rate` for '%s' throttle"
+ % self.__class__.__name__
+ )
raise ImproperlyConfigured(msg)
try:
@@ -103,9 +106,9 @@ class SimpleRateThrottle(BaseThrottle):
"""
if rate is None:
return (None, None)
- num, period = rate.split('/')
+ num, period = rate.split("/")
num_requests = int(num)
- duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
+ duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]]
return (num_requests, duration)
def allow_request(self, request, view):
@@ -170,15 +173,16 @@ class AnonRateThrottle(SimpleRateThrottle):
The IP address of the request will be used as the unique cache key.
"""
- scope = 'anon'
+
+ scope = "anon"
def get_cache_key(self, request, view):
if request.user.is_authenticated:
return None # Only throttle unauthenticated requests.
return self.cache_format % {
- 'scope': self.scope,
- 'ident': self.get_ident(request)
+ "scope": self.scope,
+ "ident": self.get_ident(request),
}
@@ -190,7 +194,8 @@ class UserRateThrottle(SimpleRateThrottle):
authenticated. For anonymous requests, the IP address of the request will
be used.
"""
- scope = 'user'
+
+ scope = "user"
def get_cache_key(self, request, view):
if request.user.is_authenticated:
@@ -198,10 +203,7 @@ class UserRateThrottle(SimpleRateThrottle):
else:
ident = self.get_ident(request)
- return self.cache_format % {
- 'scope': self.scope,
- 'ident': ident
- }
+ return self.cache_format % {"scope": self.scope, "ident": ident}
class ScopedRateThrottle(SimpleRateThrottle):
@@ -211,7 +213,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
throttled. The unique cache key will be generated by concatenating the
user id of the request, and the scope of the view being accessed.
"""
- scope_attr = 'throttle_scope'
+
+ scope_attr = "throttle_scope"
def __init__(self):
# Override the usual SimpleRateThrottle, because we can't determine
@@ -246,7 +249,4 @@ class ScopedRateThrottle(SimpleRateThrottle):
else:
ident = self.get_ident(request)
- return self.cache_format % {
- 'scope': self.scope,
- 'ident': ident
- }
+ return self.cache_format % {"scope": self.scope, "ident": ident}
diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py
index ab3a74978..076f03b38 100644
--- a/rest_framework/urlpatterns.py
+++ b/rest_framework/urlpatterns.py
@@ -3,7 +3,11 @@ from __future__ import unicode_literals
from django.conf.urls import include, url
from rest_framework.compat import (
- URLResolver, get_regex_pattern, is_route_pattern, path, register_converter
+ URLResolver,
+ get_regex_pattern,
+ is_route_pattern,
+ path,
+ register_converter,
)
from rest_framework.settings import api_settings
@@ -13,7 +17,7 @@ def _get_format_path_converter(suffix_kwarg, allowed):
if len(allowed) == 1:
allowed_pattern = allowed[0]
else:
- allowed_pattern = '(?:%s)' % '|'.join(allowed)
+ allowed_pattern = "(?:%s)" % "|".join(allowed)
suffix_pattern = r"\.%s/?" % allowed_pattern
else:
suffix_pattern = r"\.[a-z0-9]+/?"
@@ -22,19 +26,21 @@ def _get_format_path_converter(suffix_kwarg, allowed):
regex = suffix_pattern
def to_python(self, value):
- return value.strip('./')
+ return value.strip("./")
def to_url(self, value):
- return '.' + value + '/'
+ return "." + value + "/"
- converter_name = 'drf_format_suffix'
+ converter_name = "drf_format_suffix"
if allowed:
- converter_name += '_' + '_'.join(allowed)
+ converter_name += "_" + "_".join(allowed)
return converter_name, FormatSuffixConverter
-def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route=None):
+def apply_suffix_patterns(
+ urlpatterns, suffix_pattern, suffix_required, suffix_route=None
+):
ret = []
for urlpattern in urlpatterns:
if isinstance(urlpattern, URLResolver):
@@ -44,23 +50,28 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
app_name = urlpattern.app_name
kwargs = urlpattern.default_kwargs
# Add in the included patterns, after applying the suffixes
- patterns = apply_suffix_patterns(urlpattern.url_patterns,
- suffix_pattern,
- suffix_required,
- suffix_route)
+ patterns = apply_suffix_patterns(
+ urlpattern.url_patterns, suffix_pattern, suffix_required, suffix_route
+ )
# if the original pattern was a RoutePattern we need to preserve it
if is_route_pattern(urlpattern):
assert path is not None
route = str(urlpattern.pattern)
- new_pattern = path(route, include((patterns, app_name), namespace), kwargs)
+ new_pattern = path(
+ route, include((patterns, app_name), namespace), kwargs
+ )
else:
- new_pattern = url(regex, include((patterns, app_name), namespace), kwargs)
+ new_pattern = url(
+ regex, include((patterns, app_name), namespace), kwargs
+ )
ret.append(new_pattern)
else:
# Regular URL pattern
- regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern
+ regex = (
+ get_regex_pattern(urlpattern).rstrip("$").rstrip("/") + suffix_pattern
+ )
view = urlpattern.callback
kwargs = urlpattern.default_args
name = urlpattern.name
@@ -72,7 +83,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
if is_route_pattern(urlpattern):
assert path is not None
assert suffix_route is not None
- route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route
+ route = str(urlpattern.pattern).rstrip("$").rstrip("/") + suffix_route
new_pattern = path(route, view, kwargs, name)
else:
new_pattern = url(regex, view, kwargs, name)
@@ -103,17 +114,21 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
if len(allowed) == 1:
allowed_pattern = allowed[0]
else:
- allowed_pattern = '(%s)' % '|'.join(allowed)
- suffix_pattern = r'\.(?P<%s>%s)/?$' % (suffix_kwarg, allowed_pattern)
+ allowed_pattern = "(%s)" % "|".join(allowed)
+ suffix_pattern = r"\.(?P<%s>%s)/?$" % (suffix_kwarg, allowed_pattern)
else:
- suffix_pattern = r'\.(?P<%s>[a-z0-9]+)/?$' % suffix_kwarg
+ suffix_pattern = r"\.(?P<%s>[a-z0-9]+)/?$" % suffix_kwarg
if path and register_converter:
- converter_name, suffix_converter = _get_format_path_converter(suffix_kwarg, allowed)
+ converter_name, suffix_converter = _get_format_path_converter(
+ suffix_kwarg, allowed
+ )
register_converter(suffix_converter, converter_name)
- suffix_route = '<%s:%s>' % (converter_name, suffix_kwarg)
+ suffix_route = "<%s:%s>" % (converter_name, suffix_kwarg)
else:
suffix_route = None
- return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route)
+ return apply_suffix_patterns(
+ urlpatterns, suffix_pattern, suffix_required, suffix_route
+ )
diff --git a/rest_framework/urls.py b/rest_framework/urls.py
index 0e4c2661b..6932e5e27 100644
--- a/rest_framework/urls.py
+++ b/rest_framework/urls.py
@@ -16,8 +16,13 @@ from __future__ import unicode_literals
from django.conf.urls import url
from django.contrib.auth import views
-app_name = 'rest_framework'
+
+app_name = "rest_framework"
urlpatterns = [
- url(r'^login/$', views.LoginView.as_view(template_name='rest_framework/login.html'), name='login'),
- url(r'^logout/$', views.LogoutView.as_view(), name='logout'),
+ url(
+ r"^login/$",
+ views.LoginView.as_view(template_name="rest_framework/login.html"),
+ name="login",
+ ),
+ url(r"^logout/$", views.LogoutView.as_view(), name="logout"),
]
diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py
index e0374ffd0..b3626a52e 100644
--- a/rest_framework/utils/breadcrumbs.py
+++ b/rest_framework/utils/breadcrumbs.py
@@ -23,8 +23,8 @@ def get_breadcrumbs(url, request=None):
else:
# Check if this is a REST framework view,
# and if so add it to the breadcrumbs
- cls = getattr(view, 'cls', None)
- initkwargs = getattr(view, 'initkwargs', {})
+ cls = getattr(view, "cls", None)
+ initkwargs = getattr(view, "initkwargs", {})
if cls is not None and issubclass(cls, APIView):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
@@ -35,21 +35,21 @@ def get_breadcrumbs(url, request=None):
breadcrumbs_list.insert(0, (name, insert_url))
seen.append(view)
- if url == '':
+ if url == "":
# All done
return breadcrumbs_list
- elif url.endswith('/'):
+ elif url.endswith("/"):
# Drop trailing slash off the end and continue to try to
# resolve more breadcrumbs
- url = url.rstrip('/')
+ url = url.rstrip("/")
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to
# resolve more breadcrumbs
- url = url[:url.rfind('/') + 1]
+ url = url[: url.rfind("/") + 1]
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
- prefix = get_script_prefix().rstrip('/')
- url = url[len(prefix):]
+ prefix = get_script_prefix().rstrip("/")
+ url = url[len(prefix) :]
return breadcrumbs_recursive(url, [], prefix, [])
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index d8f4aeb4e..66ecfb133 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -21,6 +21,7 @@ class JSONEncoder(json.JSONEncoder):
JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, generators and other basic python objects.
"""
+
def default(self, obj):
# For Date Time string spec, see ECMA 262
# https://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
@@ -28,8 +29,8 @@ class JSONEncoder(json.JSONEncoder):
return force_text(obj)
elif isinstance(obj, datetime.datetime):
representation = obj.isoformat()
- if representation.endswith('+00:00'):
- representation = representation[:-6] + 'Z'
+ if representation.endswith("+00:00"):
+ representation = representation[:-6] + "Z"
return representation
elif isinstance(obj, datetime.date):
return obj.isoformat()
@@ -49,20 +50,22 @@ class JSONEncoder(json.JSONEncoder):
return tuple(obj)
elif isinstance(obj, bytes):
# Best-effort for binary blobs. See #4187.
- return obj.decode('utf-8')
- elif hasattr(obj, 'tolist'):
+ return obj.decode("utf-8")
+ elif hasattr(obj, "tolist"):
# Numpy arrays and array scalars.
return obj.tolist()
- elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)):
+ elif (coreapi is not None) and isinstance(
+ obj, (coreapi.Document, coreapi.Error)
+ ):
raise RuntimeError(
- 'Cannot return a coreapi object from a JSON view. '
- 'You should be using a schema renderer instead for this view.'
+ "Cannot return a coreapi object from a JSON view. "
+ "You should be using a schema renderer instead for this view."
)
- elif hasattr(obj, '__getitem__'):
+ elif hasattr(obj, "__getitem__"):
try:
return dict(obj)
except Exception:
pass
- elif hasattr(obj, '__iter__'):
+ elif hasattr(obj, "__iter__"):
return tuple(item for item in obj)
return super(JSONEncoder, self).default(obj)
diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py
index 927d08ff2..55a1fd603 100644
--- a/rest_framework/utils/field_mapping.py
+++ b/rest_framework/utils/field_mapping.py
@@ -11,8 +11,12 @@ from django.utils.text import capfirst
from rest_framework.compat import postgres_fields
from rest_framework.validators import UniqueValidator
+
NUMERIC_FIELD_TYPES = (
- models.IntegerField, models.FloatField, models.DecimalField, models.DurationField,
+ models.IntegerField,
+ models.FloatField,
+ models.DecimalField,
+ models.DurationField,
)
@@ -23,11 +27,12 @@ class ClassLookupDict(object):
hierarchy in method resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
+
def __init__(self, mapping):
self.mapping = mapping
def __getitem__(self, key):
- if hasattr(key, '_proxy_class'):
+ if hasattr(key, "_proxy_class"):
# Deal with proxy classes. Ie. BoundField behaves as if it
# is a Field instance when using ClassLookupDict.
base_class = key._proxy_class
@@ -37,7 +42,7 @@ class ClassLookupDict(object):
for cls in inspect.getmro(base_class):
if cls in self.mapping:
return self.mapping[cls]
- raise KeyError('Class %s not found in lookup.' % base_class.__name__)
+ raise KeyError("Class %s not found in lookup." % base_class.__name__)
def __setitem__(self, key, value):
self.mapping[key] = value
@@ -48,7 +53,7 @@ def needs_label(model_field, field_name):
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
- default_label = field_name.replace('_', ' ').capitalize()
+ default_label = field_name.replace("_", " ").capitalize()
return capfirst(model_field.verbose_name) != default_label
@@ -57,9 +62,9 @@ def get_detail_view_name(model):
Given a model class, return the view name to use for URL relationships
that refer to instances of the model.
"""
- return '%(model_name)s-detail' % {
- 'app_label': model._meta.app_label,
- 'model_name': model._meta.object_name.lower()
+ return "%(model_name)s-detail" % {
+ "app_label": model._meta.app_label,
+ "model_name": model._meta.object_name.lower(),
}
@@ -72,84 +77,98 @@ def get_field_kwargs(field_name, model_field):
# The following will only be used by ModelField classes.
# Gets removed for everything else.
- kwargs['model_field'] = model_field
+ kwargs["model_field"] = model_field
if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
+ kwargs["label"] = capfirst(model_field.verbose_name)
if model_field.help_text:
- kwargs['help_text'] = model_field.help_text
+ kwargs["help_text"] = model_field.help_text
- max_digits = getattr(model_field, 'max_digits', None)
+ max_digits = getattr(model_field, "max_digits", None)
if max_digits is not None:
- kwargs['max_digits'] = max_digits
+ kwargs["max_digits"] = max_digits
- decimal_places = getattr(model_field, 'decimal_places', None)
+ decimal_places = getattr(model_field, "decimal_places", None)
if decimal_places is not None:
- kwargs['decimal_places'] = decimal_places
+ kwargs["decimal_places"] = decimal_places
if isinstance(model_field, models.SlugField):
- kwargs['allow_unicode'] = model_field.allow_unicode
+ kwargs["allow_unicode"] = model_field.allow_unicode
- if isinstance(model_field, models.TextField) or (postgres_fields and isinstance(model_field, postgres_fields.JSONField)):
- kwargs['style'] = {'base_template': 'textarea.html'}
+ if isinstance(model_field, models.TextField) or (
+ postgres_fields and isinstance(model_field, postgres_fields.JSONField)
+ ):
+ kwargs["style"] = {"base_template": "textarea.html"}
if isinstance(model_field, models.AutoField) or not model_field.editable:
# If this field is read-only, then return early.
# Further keyword arguments are not valid.
- kwargs['read_only'] = True
+ kwargs["read_only"] = True
return kwargs
if model_field.has_default() or model_field.blank or model_field.null:
- kwargs['required'] = False
+ kwargs["required"] = False
if model_field.null and not isinstance(model_field, models.NullBooleanField):
- kwargs['allow_null'] = True
+ kwargs["allow_null"] = True
- if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))):
- kwargs['allow_blank'] = True
+ if model_field.blank and (
+ isinstance(model_field, (models.CharField, models.TextField))
+ ):
+ kwargs["allow_blank"] = True
if isinstance(model_field, models.FilePathField):
- kwargs['path'] = model_field.path
+ kwargs["path"] = model_field.path
if model_field.match is not None:
- kwargs['match'] = model_field.match
+ kwargs["match"] = model_field.match
if model_field.recursive is not False:
- kwargs['recursive'] = model_field.recursive
+ kwargs["recursive"] = model_field.recursive
if model_field.allow_files is not True:
- kwargs['allow_files'] = model_field.allow_files
+ kwargs["allow_files"] = model_field.allow_files
if model_field.allow_folders is not False:
- kwargs['allow_folders'] = model_field.allow_folders
+ kwargs["allow_folders"] = model_field.allow_folders
if model_field.choices:
- kwargs['choices'] = model_field.choices
+ kwargs["choices"] = model_field.choices
else:
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
- max_value = next((
- validator.limit_value for validator in validator_kwarg
- if isinstance(validator, validators.MaxValueValidator)
- ), None)
+ max_value = next(
+ (
+ validator.limit_value
+ for validator in validator_kwarg
+ if isinstance(validator, validators.MaxValueValidator)
+ ),
+ None,
+ )
if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
- kwargs['max_value'] = max_value
+ kwargs["max_value"] = max_value
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that min_value is passed explicitly as a keyword arg,
# rather than as a validator.
- min_value = next((
- validator.limit_value for validator in validator_kwarg
- if isinstance(validator, validators.MinValueValidator)
- ), None)
+ min_value = next(
+ (
+ validator.limit_value
+ for validator in validator_kwarg
+ if isinstance(validator, validators.MinValueValidator)
+ ),
+ None,
+ )
if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
- kwargs['min_value'] = min_value
+ kwargs["min_value"] = min_value
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
@@ -157,7 +176,8 @@ def get_field_kwargs(field_name, model_field):
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
@@ -165,67 +185,79 @@ def get_field_kwargs(field_name, model_field):
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if validator is not validators.validate_slug
]
# IPAddressField do not need to include the 'validate_ipv46_address' argument,
if isinstance(model_field, models.GenericIPAddressField):
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if validator is not validators.validate_ipv46_address
]
# Our decimal validation is handled in the field code, not validator code.
if isinstance(model_field, models.DecimalField):
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if not isinstance(validator, validators.DecimalValidator)
]
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
- max_length = getattr(model_field, 'max_length', None)
- if max_length is not None and (isinstance(model_field, (models.CharField, models.TextField, models.FileField))):
- kwargs['max_length'] = max_length
+ max_length = getattr(model_field, "max_length", None)
+ if max_length is not None and (
+ isinstance(model_field, (models.CharField, models.TextField, models.FileField))
+ ):
+ kwargs["max_length"] = max_length
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
- min_length = next((
- validator.limit_value for validator in validator_kwarg
- if isinstance(validator, validators.MinLengthValidator)
- ), None)
+ min_length = next(
+ (
+ validator.limit_value
+ for validator in validator_kwarg
+ if isinstance(validator, validators.MinLengthValidator)
+ ),
+ None,
+ )
if min_length is not None and isinstance(model_field, models.CharField):
- kwargs['min_length'] = min_length
+ kwargs["min_length"] = min_length
validator_kwarg = [
- validator for validator in validator_kwarg
+ validator
+ for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
- if getattr(model_field, 'unique', False):
- unique_error_message = model_field.error_messages.get('unique', None)
+ if getattr(model_field, "unique", False):
+ unique_error_message = model_field.error_messages.get("unique", None)
if unique_error_message:
unique_error_message = unique_error_message % {
- 'model_name': model_field.model._meta.verbose_name,
- 'field_label': model_field.verbose_name
+ "model_name": model_field.model._meta.verbose_name,
+ "field_label": model_field.verbose_name,
}
validator = UniqueValidator(
- queryset=model_field.model._default_manager,
- message=unique_error_message)
+ queryset=model_field.model._default_manager, message=unique_error_message
+ )
validator_kwarg.append(validator)
if validator_kwarg:
- kwargs['validators'] = validator_kwarg
+ kwargs["validators"] = validator_kwarg
return kwargs
@@ -234,65 +266,65 @@ def get_relation_kwargs(field_name, relation_info):
"""
Creates a default instance of a flat relational field.
"""
- model_field, related_model, to_many, to_field, has_through_model, reverse = relation_info
+ model_field, related_model, to_many, to_field, has_through_model, reverse = (
+ relation_info
+ )
kwargs = {
- 'queryset': related_model._default_manager,
- 'view_name': get_detail_view_name(related_model)
+ "queryset": related_model._default_manager,
+ "view_name": get_detail_view_name(related_model),
}
if to_many:
- kwargs['many'] = True
+ kwargs["many"] = True
if to_field:
- kwargs['to_field'] = to_field
+ kwargs["to_field"] = to_field
limit_choices_to = model_field and model_field.get_limit_choices_to()
if limit_choices_to:
if not isinstance(limit_choices_to, models.Q):
limit_choices_to = models.Q(**limit_choices_to)
- kwargs['queryset'] = kwargs['queryset'].filter(limit_choices_to)
+ kwargs["queryset"] = kwargs["queryset"].filter(limit_choices_to)
if has_through_model:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
+ kwargs["read_only"] = True
+ kwargs.pop("queryset", None)
if model_field:
if model_field.verbose_name and needs_label(model_field, field_name):
- kwargs['label'] = capfirst(model_field.verbose_name)
+ kwargs["label"] = capfirst(model_field.verbose_name)
help_text = model_field.help_text
if help_text:
- kwargs['help_text'] = help_text
+ kwargs["help_text"] = help_text
if not model_field.editable:
- kwargs['read_only'] = True
- kwargs.pop('queryset', None)
- if kwargs.get('read_only', False):
+ kwargs["read_only"] = True
+ kwargs.pop("queryset", None)
+ if kwargs.get("read_only", False):
# If this field is read-only, then return early.
# No further keyword arguments are valid.
return kwargs
if model_field.has_default() or model_field.blank or model_field.null:
- kwargs['required'] = False
+ kwargs["required"] = False
if model_field.null:
- kwargs['allow_null'] = True
+ kwargs["allow_null"] = True
if model_field.validators:
- kwargs['validators'] = model_field.validators
- if getattr(model_field, 'unique', False):
+ kwargs["validators"] = model_field.validators
+ if getattr(model_field, "unique", False):
validator = UniqueValidator(queryset=model_field.model._default_manager)
- kwargs['validators'] = kwargs.get('validators', []) + [validator]
+ kwargs["validators"] = kwargs.get("validators", []) + [validator]
if to_many and not model_field.blank:
- kwargs['allow_empty'] = False
+ kwargs["allow_empty"] = False
return kwargs
def get_nested_relation_kwargs(relation_info):
- kwargs = {'read_only': True}
+ kwargs = {"read_only": True}
if relation_info.to_many:
- kwargs['many'] = True
+ kwargs["many"] = True
return kwargs
def get_url_kwargs(model_field):
- return {
- 'view_name': get_detail_view_name(model_field)
- }
+ return {"view_name": get_detail_view_name(model_field)}
diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py
index aa805f14e..0fcfa634a 100644
--- a/rest_framework/utils/formatting.py
+++ b/rest_framework/utils/formatting.py
@@ -18,7 +18,7 @@ def remove_trailing_string(content, trailing):
Used when generating names from view classes.
"""
if content.endswith(trailing) and content != trailing:
- return content[:-len(trailing)]
+ return content[: -len(trailing)]
return content
@@ -36,14 +36,14 @@ def dedent(content):
# unindent the content if needed
if lines:
- whitespace_counts = min([len(line) - len(line.lstrip(' ')) for line in lines])
- tab_counts = min([len(line) - len(line.lstrip('\t')) for line in lines])
+ whitespace_counts = min([len(line) - len(line.lstrip(" ")) for line in lines])
+ tab_counts = min([len(line) - len(line.lstrip("\t")) for line in lines])
if whitespace_counts:
- whitespace_pattern = '^' + (' ' * whitespace_counts)
- content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
+ whitespace_pattern = "^" + (" " * whitespace_counts)
+ content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content)
elif tab_counts:
- whitespace_pattern = '^' + ('\t' * tab_counts)
- content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
+ whitespace_pattern = "^" + ("\t" * tab_counts)
+ content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content)
return content.strip()
@@ -52,9 +52,9 @@ def camelcase_to_spaces(content):
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes.
"""
- camelcase_boundary = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
- content = re.sub(camelcase_boundary, ' \\1', content).strip()
- return ' '.join(content.split('_')).title()
+ camelcase_boundary = "(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))"
+ content = re.sub(camelcase_boundary, " \\1", content).strip()
+ return " ".join(content.split("_")).title()
def markup_description(description):
@@ -64,6 +64,6 @@ def markup_description(description):
if apply_markdown:
description = apply_markdown(description)
else:
- description = escape(description).replace('\n', '
')
- description = '' + description + '
'
+ description = escape(description).replace("\n", "
")
+ description = "" + description + "
"
return mark_safe(description)
diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py
index c7ede7803..48820327f 100644
--- a/rest_framework/utils/html.py
+++ b/rest_framework/utils/html.py
@@ -9,10 +9,10 @@ from django.utils.datastructures import MultiValueDict
def is_html_input(dictionary):
# MultiDict type datastructures are used to represent HTML form input,
# which may have more than one value for each key.
- return hasattr(dictionary, 'getlist')
+ return hasattr(dictionary, "getlist")
-def parse_html_list(dictionary, prefix='', default=None):
+def parse_html_list(dictionary, prefix="", default=None):
"""
Used to support list values in HTML forms.
Supports lists of primitives and/or dictionaries.
@@ -48,7 +48,7 @@ def parse_html_list(dictionary, prefix='', default=None):
:returns a list of objects, or the value specified in ``default`` if the list is empty
"""
ret = {}
- regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
+ regex = re.compile(r"^%s\[([0-9]+)\](.*)$" % re.escape(prefix))
for field, value in dictionary.items():
match = regex.match(field)
if not match:
@@ -66,7 +66,7 @@ def parse_html_list(dictionary, prefix='', default=None):
return [ret[item] for item in sorted(ret)] if ret else default
-def parse_html_dict(dictionary, prefix=''):
+def parse_html_dict(dictionary, prefix=""):
"""
Used to support dictionary values in HTML forms.
@@ -83,7 +83,7 @@ def parse_html_dict(dictionary, prefix=''):
}
"""
ret = MultiValueDict()
- regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
+ regex = re.compile(r"^%s\.(.+)$" % re.escape(prefix))
for field in dictionary:
match = regex.match(field)
if not match:
diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py
index 48ef89547..14ca1f663 100644
--- a/rest_framework/utils/humanize_datetime.py
+++ b/rest_framework/utils/humanize_datetime.py
@@ -5,20 +5,19 @@ from rest_framework import ISO_8601
def datetime_formats(formats):
- format = ', '.join(formats).replace(
- ISO_8601,
- 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
+ format = ", ".join(formats).replace(
+ ISO_8601, "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"
)
return humanize_strptime(format)
def date_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DD')
+ format = ", ".join(formats).replace(ISO_8601, "YYYY-MM-DD")
return humanize_strptime(format)
def time_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
+ format = ", ".join(formats).replace(ISO_8601, "hh:mm[:ss[.uuuuuu]]")
return humanize_strptime(format)
@@ -40,7 +39,7 @@ def humanize_strptime(format_string):
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
- "%z": "[+HHMM|-HHMM]"
+ "%z": "[+HHMM|-HHMM]",
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)
diff --git a/rest_framework/utils/json.py b/rest_framework/utils/json.py
index cb5572380..09ba12fe2 100644
--- a/rest_framework/utils/json.py
+++ b/rest_framework/utils/json.py
@@ -13,28 +13,28 @@ import json # noqa
def strict_constant(o):
- raise ValueError('Out of range float values are not JSON compliant: ' + repr(o))
+ raise ValueError("Out of range float values are not JSON compliant: " + repr(o))
@functools.wraps(json.dump)
def dump(*args, **kwargs):
- kwargs.setdefault('allow_nan', False)
+ kwargs.setdefault("allow_nan", False)
return json.dump(*args, **kwargs)
@functools.wraps(json.dumps)
def dumps(*args, **kwargs):
- kwargs.setdefault('allow_nan', False)
+ kwargs.setdefault("allow_nan", False)
return json.dumps(*args, **kwargs)
@functools.wraps(json.load)
def load(*args, **kwargs):
- kwargs.setdefault('parse_constant', strict_constant)
+ kwargs.setdefault("parse_constant", strict_constant)
return json.load(*args, **kwargs)
@functools.wraps(json.loads)
def loads(*args, **kwargs):
- kwargs.setdefault('parse_constant', strict_constant)
+ kwargs.setdefault("parse_constant", strict_constant)
return json.loads(*args, **kwargs)
diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py
index f4acf4807..2a5752d8f 100644
--- a/rest_framework/utils/mediatypes.py
+++ b/rest_framework/utils/mediatypes.py
@@ -49,20 +49,30 @@ def order_by_precedence(media_type_lst):
@python_2_unicode_compatible
class _MediaType(object):
def __init__(self, media_type_str):
- self.orig = '' if (media_type_str is None) else media_type_str
- self.full_type, self.params = parse_header(self.orig.encode(HTTP_HEADER_ENCODING))
- self.main_type, sep, self.sub_type = self.full_type.partition('/')
+ self.orig = "" if (media_type_str is None) else media_type_str
+ self.full_type, self.params = parse_header(
+ self.orig.encode(HTTP_HEADER_ENCODING)
+ )
+ self.main_type, sep, self.sub_type = self.full_type.partition("/")
def match(self, other):
"""Return true if this MediaType satisfies the given MediaType."""
for key in self.params:
- if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
+ if key != "q" and other.params.get(key, None) != self.params.get(key, None):
return False
- if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
+ if (
+ self.sub_type != "*"
+ and other.sub_type != "*"
+ and other.sub_type != self.sub_type
+ ):
return False
- if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:
+ if (
+ self.main_type != "*"
+ and other.main_type != "*"
+ and other.main_type != self.main_type
+ ):
return False
return True
@@ -72,16 +82,16 @@ class _MediaType(object):
"""
Return a precedence level from 0-3 for the media type given how specific it is.
"""
- if self.main_type == '*':
+ if self.main_type == "*":
return 0
- elif self.sub_type == '*':
+ elif self.sub_type == "*":
return 1
- elif not self.params or list(self.params) == ['q']:
+ elif not self.params or list(self.params) == ["q"]:
return 2
return 3
def __str__(self):
ret = "%s/%s" % (self.main_type, self.sub_type)
for key, val in self.params.items():
- ret += "; %s=%s" % (key, val.decode('ascii'))
+ ret += "; %s=%s" % (key, val.decode("ascii"))
return ret
diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py
index 4cc93b8ef..1bc8f1551 100644
--- a/rest_framework/utils/model_meta.py
+++ b/rest_framework/utils/model_meta.py
@@ -7,23 +7,30 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import OrderedDict, namedtuple
-FieldInfo = namedtuple('FieldResult', [
- 'pk', # Model field instance
- 'fields', # Dict of field name -> model field instance
- 'forward_relations', # Dict of field name -> RelationInfo
- 'reverse_relations', # Dict of field name -> RelationInfo
- 'fields_and_pk', # Shortcut for 'pk' + 'fields'
- 'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
-])
-RelationInfo = namedtuple('RelationInfo', [
- 'model_field',
- 'related_model',
- 'to_many',
- 'to_field',
- 'has_through_model',
- 'reverse'
-])
+FieldInfo = namedtuple(
+ "FieldResult",
+ [
+ "pk", # Model field instance
+ "fields", # Dict of field name -> model field instance
+ "forward_relations", # Dict of field name -> RelationInfo
+ "reverse_relations", # Dict of field name -> RelationInfo
+ "fields_and_pk", # Shortcut for 'pk' + 'fields'
+ "relations", # Shortcut for 'forward_relations' + 'reverse_relations'
+ ],
+)
+
+RelationInfo = namedtuple(
+ "RelationInfo",
+ [
+ "model_field",
+ "related_model",
+ "to_many",
+ "to_field",
+ "has_through_model",
+ "reverse",
+ ],
+)
def get_field_info(model):
@@ -41,8 +48,9 @@ def get_field_info(model):
fields_and_pk = _merge_fields_and_pk(pk, fields)
relationships = _merge_relationships(forward_relations, reverse_relations)
- return FieldInfo(pk, fields, forward_relations, reverse_relations,
- fields_and_pk, relationships)
+ return FieldInfo(
+ pk, fields, forward_relations, reverse_relations, fields_and_pk, relationships
+ )
def _get_pk(opts):
@@ -59,14 +67,16 @@ def _get_pk(opts):
def _get_fields(opts):
fields = OrderedDict()
- for field in [field for field in opts.fields if field.serialize and not field.remote_field]:
+ for field in [
+ field for field in opts.fields if field.serialize and not field.remote_field
+ ]:
fields[field.name] = field
return fields
def _get_to_field(field):
- return getattr(field, 'to_fields', None) and field.to_fields[0]
+ return getattr(field, "to_fields", None) and field.to_fields[0]
def _get_forward_relationships(opts):
@@ -74,14 +84,16 @@ def _get_forward_relationships(opts):
Returns an `OrderedDict` of field names to `RelationInfo`.
"""
forward_relations = OrderedDict()
- for field in [field for field in opts.fields if field.serialize and field.remote_field]:
+ for field in [
+ field for field in opts.fields if field.serialize and field.remote_field
+ ]:
forward_relations[field.name] = RelationInfo(
model_field=field,
related_model=field.remote_field.model,
to_many=False,
to_field=_get_to_field(field),
has_through_model=False,
- reverse=False
+ reverse=False,
)
# Deal with forward many-to-many relationships.
@@ -92,10 +104,8 @@ def _get_forward_relationships(opts):
to_many=True,
# manytomany do not have to_fields
to_field=None,
- has_through_model=(
- not field.remote_field.through._meta.auto_created
- ),
- reverse=False
+ has_through_model=(not field.remote_field.through._meta.auto_created),
+ reverse=False,
)
return forward_relations
@@ -115,11 +125,13 @@ def _get_reverse_relationships(opts):
to_many=relation.field.remote_field.multiple,
to_field=_get_to_field(relation.field),
has_through_model=False,
- reverse=True
+ reverse=True,
)
# Deal with reverse many-to-many relationships.
- all_related_many_to_many_objects = [r for r in opts.related_objects if r.field.many_to_many]
+ all_related_many_to_many_objects = [
+ r for r in opts.related_objects if r.field.many_to_many
+ ]
for relation in all_related_many_to_many_objects:
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
@@ -129,10 +141,10 @@ def _get_reverse_relationships(opts):
# manytomany do not have to_fields
to_field=None,
has_through_model=(
- (getattr(relation.field.remote_field, 'through', None) is not None) and
- not relation.field.remote_field.through._meta.auto_created
+ (getattr(relation.field.remote_field, "through", None) is not None)
+ and not relation.field.remote_field.through._meta.auto_created
),
- reverse=True
+ reverse=True,
)
return reverse_relations
@@ -140,7 +152,7 @@ def _get_reverse_relationships(opts):
def _merge_fields_and_pk(pk, fields):
fields_and_pk = OrderedDict()
- fields_and_pk['pk'] = pk
+ fields_and_pk["pk"] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
@@ -149,8 +161,7 @@ def _merge_fields_and_pk(pk, fields):
def _merge_relationships(forward_relations, reverse_relations):
return OrderedDict(
- list(forward_relations.items()) +
- list(reverse_relations.items())
+ list(forward_relations.items()) + list(reverse_relations.items())
)
@@ -158,4 +169,8 @@ def is_abstract_model(model):
"""
Given a model class, returns a boolean True if it is abstract and False if it is not.
"""
- return hasattr(model, '_meta') and hasattr(model._meta, 'abstract') and model._meta.abstract
+ return (
+ hasattr(model, "_meta")
+ and hasattr(model._meta, "abstract")
+ and model._meta.abstract
+ )
diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py
index deeaf1f63..43a1bb3c2 100644
--- a/rest_framework/utils/representation.py
+++ b/rest_framework/utils/representation.py
@@ -16,14 +16,10 @@ from rest_framework.compat import unicode_repr
def manager_repr(value):
model = value.model
opts = model._meta
- names_and_managers = [
- (manager.name, manager)
- for manager
- in opts.managers
- ]
+ names_and_managers = [(manager.name, manager) for manager in opts.managers]
for manager_name, manager_instance in names_and_managers:
if manager_instance == value:
- return '%s.%s.all()' % (model._meta.object_name, manager_name)
+ return "%s.%s.all()" % (model._meta.object_name, manager_name)
return repr(value)
@@ -45,7 +41,7 @@ def smart_repr(value):
#
# Should be presented as
#
- value = re.sub(' at 0x[0-9A-Fa-f]{4,32}>', '>', value)
+ value = re.sub(" at 0x[0-9A-Fa-f]{4,32}>", ">", value)
return value
@@ -54,16 +50,15 @@ def field_repr(field, force_many=False):
kwargs = field._kwargs
if force_many:
kwargs = kwargs.copy()
- kwargs['many'] = True
- kwargs.pop('child', None)
+ kwargs["many"] = True
+ kwargs.pop("child", None)
- arg_string = ', '.join([smart_repr(val) for val in field._args])
- kwarg_string = ', '.join([
- '%s=%s' % (key, smart_repr(val))
- for key, val in sorted(kwargs.items())
- ])
+ arg_string = ", ".join([smart_repr(val) for val in field._args])
+ kwarg_string = ", ".join(
+ ["%s=%s" % (key, smart_repr(val)) for key, val in sorted(kwargs.items())]
+ )
if arg_string and kwarg_string:
- arg_string += ', '
+ arg_string += ", "
if force_many:
class_name = force_many.__class__.__name__
@@ -74,8 +69,8 @@ def field_repr(field, force_many=False):
def serializer_repr(serializer, indent, force_many=None):
- ret = field_repr(serializer, force_many) + ':'
- indent_str = ' ' * indent
+ ret = field_repr(serializer, force_many) + ":"
+ indent_str = " " * indent
if force_many:
fields = force_many.fields
@@ -83,25 +78,27 @@ def serializer_repr(serializer, indent, force_many=None):
fields = serializer.fields
for field_name, field in fields.items():
- ret += '\n' + indent_str + field_name + ' = '
- if hasattr(field, 'fields'):
+ ret += "\n" + indent_str + field_name + " = "
+ if hasattr(field, "fields"):
ret += serializer_repr(field, indent + 1)
- elif hasattr(field, 'child'):
+ elif hasattr(field, "child"):
ret += list_repr(field, indent + 1)
- elif hasattr(field, 'child_relation'):
+ elif hasattr(field, "child_relation"):
ret += field_repr(field.child_relation, force_many=field.child_relation)
else:
ret += field_repr(field)
if serializer.validators:
- ret += '\n' + indent_str + 'class Meta:'
- ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators)
+ ret += "\n" + indent_str + "class Meta:"
+ ret += (
+ "\n" + indent_str + " validators = " + smart_repr(serializer.validators)
+ )
return ret
def list_repr(serializer, indent):
child = serializer.child
- if hasattr(child, 'fields'):
+ if hasattr(child, "fields"):
return serializer_repr(serializer, indent, force_many=child)
return field_repr(serializer)
diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py
index c24e51d09..e2db8c130 100644
--- a/rest_framework/utils/serializer_helpers.py
+++ b/rest_framework/utils/serializer_helpers.py
@@ -16,7 +16,7 @@ class ReturnDict(OrderedDict):
"""
def __init__(self, *args, **kwargs):
- self.serializer = kwargs.pop('serializer')
+ self.serializer = kwargs.pop("serializer")
super(ReturnDict, self).__init__(*args, **kwargs)
def copy(self):
@@ -39,7 +39,7 @@ class ReturnList(list):
"""
def __init__(self, *args, **kwargs):
- self.serializer = kwargs.pop('serializer')
+ self.serializer = kwargs.pop("serializer")
super(ReturnList, self).__init__(*args, **kwargs)
def __repr__(self):
@@ -58,7 +58,7 @@ class BoundField(object):
providing an API similar to Django forms and form fields.
"""
- def __init__(self, field, value, errors, prefix=''):
+ def __init__(self, field, value, errors, prefix=""):
self._field = field
self._prefix = prefix
self.value = value
@@ -73,12 +73,13 @@ class BoundField(object):
return self._field.__class__
def __repr__(self):
- return unicode_to_repr('<%s value=%s errors=%s>' % (
- self.__class__.__name__, self.value, self.errors
- ))
+ return unicode_to_repr(
+ "<%s value=%s errors=%s>"
+ % (self.__class__.__name__, self.value, self.errors)
+ )
def as_form_field(self):
- value = '' if (self.value is None or self.value is False) else self.value
+ value = "" if (self.value is None or self.value is False) else self.value
return self.__class__(self._field, value, self.errors, self._prefix)
@@ -87,7 +88,7 @@ class JSONBoundField(BoundField):
value = self.value
# When HTML form input is used and the input is not valid
# value will be a JSONString, rather than a JSON primitive.
- if not getattr(value, 'is_json_string', False):
+ if not getattr(value, "is_json_string", False):
try:
value = json.dumps(self.value, sort_keys=True, indent=4)
except (TypeError, ValueError):
@@ -102,8 +103,8 @@ class NestedBoundField(BoundField):
`BoundField` that is used for serializer fields.
"""
- def __init__(self, field, value, errors, prefix=''):
- if value is None or value is '':
+ def __init__(self, field, value, errors, prefix=""):
+ if value is None or value is "":
value = {}
super(NestedBoundField, self).__init__(field, value, errors, prefix)
@@ -115,9 +116,9 @@ class NestedBoundField(BoundField):
field = self.fields[key]
value = self.value.get(key) if self.value else None
error = self.errors.get(key) if isinstance(self.errors, dict) else None
- if hasattr(field, 'fields'):
- return NestedBoundField(field, value, error, prefix=self.name + '.')
- return BoundField(field, value, error, prefix=self.name + '.')
+ if hasattr(field, "fields"):
+ return NestedBoundField(field, value, error, prefix=self.name + ".")
+ return BoundField(field, value, error, prefix=self.name + ".")
def as_form_field(self):
values = {}
@@ -125,7 +126,9 @@ class NestedBoundField(BoundField):
if isinstance(value, (list, dict)):
values[key] = value
else:
- values[key] = '' if (value is None or value is False) else force_text(value)
+ values[key] = (
+ "" if (value is None or value is False) else force_text(value)
+ )
return self.__class__(self._field, values, self.errors, self._prefix)
diff --git a/rest_framework/validators.py b/rest_framework/validators.py
index 2ea3e5ac1..b2b06adf7 100644
--- a/rest_framework/validators.py
+++ b/rest_framework/validators.py
@@ -39,9 +39,10 @@ class UniqueValidator(object):
Should be applied to an individual field on the serializer.
"""
- message = _('This field must be unique.')
- def __init__(self, queryset, message=None, lookup='exact'):
+ message = _("This field must be unique.")
+
+ def __init__(self, queryset, message=None, lookup="exact"):
self.queryset = queryset
self.serializer_field = None
self.message = message or self.message
@@ -56,13 +57,13 @@ class UniqueValidator(object):
# same as the serializer field name if `source=<>` is set.
self.field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
- self.instance = getattr(serializer_field.parent, 'instance', None)
+ self.instance = getattr(serializer_field.parent, "instance", None)
def filter_queryset(self, value, queryset):
"""
Filter the queryset to all instances matching the given attribute.
"""
- filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value}
+ filter_kwargs = {"%s__%s" % (self.field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, queryset):
@@ -79,13 +80,12 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
if qs_exists(queryset):
- raise ValidationError(self.message, code='unique')
+ raise ValidationError(self.message, code="unique")
def __repr__(self):
- return unicode_to_repr('<%s(queryset=%s)>' % (
- self.__class__.__name__,
- smart_repr(self.queryset)
- ))
+ return unicode_to_repr(
+ "<%s(queryset=%s)>" % (self.__class__.__name__, smart_repr(self.queryset))
+ )
class UniqueTogetherValidator(object):
@@ -94,8 +94,9 @@ class UniqueTogetherValidator(object):
Should be applied to the serializer class, not to an individual field.
"""
- message = _('The fields {field_names} must make a unique set.')
- missing_message = _('This field is required.')
+
+ message = _("The fields {field_names} must make a unique set.")
+ missing_message = _("This field is required.")
def __init__(self, queryset, fields, message=None):
self.queryset = queryset
@@ -109,7 +110,7 @@ class UniqueTogetherValidator(object):
prior to the validation call being made.
"""
# Determine the existing instance, if this is an update operation.
- self.instance = getattr(serializer, 'instance', None)
+ self.instance = getattr(serializer, "instance", None)
def enforce_required_fields(self, attrs):
"""
@@ -125,7 +126,7 @@ class UniqueTogetherValidator(object):
if field_name not in attrs
}
if missing_items:
- raise ValidationError(missing_items, code='required')
+ raise ValidationError(missing_items, code="required")
def filter_queryset(self, attrs, queryset):
"""
@@ -139,10 +140,7 @@ class UniqueTogetherValidator(object):
attrs[field_name] = getattr(self.instance, field_name)
# Determine the filter keyword arguments and filter the queryset.
- filter_kwargs = {
- field_name: attrs[field_name]
- for field_name in self.fields
- }
+ filter_kwargs = {field_name: attrs[field_name] for field_name in self.fields}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, attrs, queryset):
@@ -165,21 +163,24 @@ class UniqueTogetherValidator(object):
value for field, value in attrs.items() if field in self.fields
]
if None not in checked_values and qs_exists(queryset):
- field_names = ', '.join(self.fields)
+ field_names = ", ".join(self.fields)
message = self.message.format(field_names=field_names)
- raise ValidationError(message, code='unique')
+ raise ValidationError(message, code="unique")
def __repr__(self):
- return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
- self.__class__.__name__,
- smart_repr(self.queryset),
- smart_repr(self.fields)
- ))
+ return unicode_to_repr(
+ "<%s(queryset=%s, fields=%s)>"
+ % (
+ self.__class__.__name__,
+ smart_repr(self.queryset),
+ smart_repr(self.fields),
+ )
+ )
class BaseUniqueForValidator(object):
message = None
- missing_message = _('This field is required.')
+ missing_message = _("This field is required.")
def __init__(self, queryset, field, date_field, message=None):
self.queryset = queryset
@@ -197,7 +198,7 @@ class BaseUniqueForValidator(object):
self.field_name = serializer.fields[self.field].source_attrs[-1]
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
- self.instance = getattr(serializer, 'instance', None)
+ self.instance = getattr(serializer, "instance", None)
def enforce_required_fields(self, attrs):
"""
@@ -210,10 +211,10 @@ class BaseUniqueForValidator(object):
if field_name not in attrs
}
if missing_items:
- raise ValidationError(missing_items, code='required')
+ raise ValidationError(missing_items, code="required")
def filter_queryset(self, attrs, queryset):
- raise NotImplementedError('`filter_queryset` must be implemented.')
+ raise NotImplementedError("`filter_queryset` must be implemented.")
def exclude_current_instance(self, attrs, queryset):
"""
@@ -231,17 +232,18 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
- raise ValidationError({
- self.field: message
- }, code='unique')
+ raise ValidationError({self.field: message}, code="unique")
def __repr__(self):
- return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (
- self.__class__.__name__,
- smart_repr(self.queryset),
- smart_repr(self.field),
- smart_repr(self.date_field)
- ))
+ return unicode_to_repr(
+ "<%s(queryset=%s, field=%s, date_field=%s)>"
+ % (
+ self.__class__.__name__,
+ smart_repr(self.queryset),
+ smart_repr(self.field),
+ smart_repr(self.date_field),
+ )
+ )
class UniqueForDateValidator(BaseUniqueForValidator):
@@ -253,9 +255,9 @@ class UniqueForDateValidator(BaseUniqueForValidator):
filter_kwargs = {}
filter_kwargs[self.field_name] = value
- filter_kwargs['%s__day' % self.date_field_name] = date.day
- filter_kwargs['%s__month' % self.date_field_name] = date.month
- filter_kwargs['%s__year' % self.date_field_name] = date.year
+ filter_kwargs["%s__day" % self.date_field_name] = date.day
+ filter_kwargs["%s__month" % self.date_field_name] = date.month
+ filter_kwargs["%s__year" % self.date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)
@@ -268,7 +270,7 @@ class UniqueForMonthValidator(BaseUniqueForValidator):
filter_kwargs = {}
filter_kwargs[self.field_name] = value
- filter_kwargs['%s__month' % self.date_field_name] = date.month
+ filter_kwargs["%s__month" % self.date_field_name] = date.month
return qs_filter(queryset, **filter_kwargs)
@@ -281,5 +283,5 @@ class UniqueForYearValidator(BaseUniqueForValidator):
filter_kwargs = {}
filter_kwargs[self.field_name] = value
- filter_kwargs['%s__year' % self.date_field_name] = date.year
+ filter_kwargs["%s__year" % self.date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)
diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py
index 206ff6c2e..4e3bf855f 100644
--- a/rest_framework/versioning.py
+++ b/rest_framework/versioning.py
@@ -19,19 +19,20 @@ class BaseVersioning(object):
version_param = api_settings.VERSION_PARAM
def determine_version(self, request, *args, **kwargs):
- msg = '{cls}.determine_version() must be implemented.'
- raise NotImplementedError(msg.format(
- cls=self.__class__.__name__
- ))
+ msg = "{cls}.determine_version() must be implemented."
+ raise NotImplementedError(msg.format(cls=self.__class__.__name__))
- def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ def reverse(
+ self, viewname, args=None, kwargs=None, request=None, format=None, **extra
+ ):
return _reverse(viewname, args, kwargs, request, format, **extra)
def is_allowed_version(self, version):
if not self.allowed_versions:
return True
- return ((version is not None and version == self.default_version) or
- (version in self.allowed_versions))
+ return (version is not None and version == self.default_version) or (
+ version in self.allowed_versions
+ )
class AcceptHeaderVersioning(BaseVersioning):
@@ -40,6 +41,7 @@ class AcceptHeaderVersioning(BaseVersioning):
Host: example.com
Accept: application/json; version=1.0
"""
+
invalid_version_message = _('Invalid version in "Accept" header.')
def determine_version(self, request, *args, **kwargs):
@@ -71,7 +73,8 @@ class URLPathVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
- invalid_version_message = _('Invalid version in URL path.')
+
+ invalid_version_message = _("Invalid version in URL path.")
def determine_version(self, request, *args, **kwargs):
version = kwargs.get(self.version_param, self.default_version)
@@ -82,7 +85,9 @@ class URLPathVersioning(BaseVersioning):
raise exceptions.NotFound(self.invalid_version_message)
return version
- def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ def reverse(
+ self, viewname, args=None, kwargs=None, request=None, format=None, **extra
+ ):
if request.version is not None:
kwargs = {} if (kwargs is None) else kwargs
kwargs[self.version_param] = request.version
@@ -116,21 +121,26 @@ class NamespaceVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
- invalid_version_message = _('Invalid version in URL path. Does not match any version namespace.')
+
+ invalid_version_message = _(
+ "Invalid version in URL path. Does not match any version namespace."
+ )
def determine_version(self, request, *args, **kwargs):
- resolver_match = getattr(request, 'resolver_match', None)
+ resolver_match = getattr(request, "resolver_match", None)
if resolver_match is None or not resolver_match.namespace:
return self.default_version
# Allow for possibly nested namespaces.
- possible_versions = resolver_match.namespace.split(':')
+ possible_versions = resolver_match.namespace.split(":")
for version in possible_versions:
if self.is_allowed_version(version):
return version
raise exceptions.NotFound(self.invalid_version_message)
- def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ def reverse(
+ self, viewname, args=None, kwargs=None, request=None, format=None, **extra
+ ):
if request.version is not None:
viewname = self.get_versioned_viewname(viewname, request)
return super(NamespaceVersioning, self).reverse(
@@ -138,7 +148,7 @@ class NamespaceVersioning(BaseVersioning):
)
def get_versioned_viewname(self, viewname, request):
- return request.version + ':' + viewname
+ return request.version + ":" + viewname
class HostNameVersioning(BaseVersioning):
@@ -147,11 +157,12 @@ class HostNameVersioning(BaseVersioning):
Host: v1.example.com
Accept: application/json
"""
- hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
- invalid_version_message = _('Invalid version in hostname.')
+
+ hostname_regex = re.compile(r"^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$")
+ invalid_version_message = _("Invalid version in hostname.")
def determine_version(self, request, *args, **kwargs):
- hostname, separator, port = request.get_host().partition(':')
+ hostname, separator, port = request.get_host().partition(":")
match = self.hostname_regex.match(hostname)
if not match:
return self.default_version
@@ -170,7 +181,8 @@ class QueryParameterVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
- invalid_version_message = _('Invalid version in query parameter.')
+
+ invalid_version_message = _("Invalid version in query parameter.")
def determine_version(self, request, *args, **kwargs):
version = request.query_params.get(self.version_param, self.default_version)
@@ -178,7 +190,9 @@ class QueryParameterVersioning(BaseVersioning):
raise exceptions.NotFound(self.invalid_version_message)
return version
- def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
+ def reverse(
+ self, viewname, args=None, kwargs=None, request=None, format=None, **extra
+ ):
url = super(QueryParameterVersioning, self).reverse(
viewname, args, kwargs, request, format, **extra
)
diff --git a/rest_framework/views.py b/rest_framework/views.py
index 9d5d959e9..d4c006dfe 100644
--- a/rest_framework/views.py
+++ b/rest_framework/views.py
@@ -29,19 +29,19 @@ def get_view_name(view):
This function is the default for the `VIEW_NAME_FUNCTION` setting.
"""
# Name may be set by some Views, such as a ViewSet.
- name = getattr(view, 'name', None)
+ name = getattr(view, "name", None)
if name is not None:
return name
name = view.__class__.__name__
- name = formatting.remove_trailing_string(name, 'View')
- name = formatting.remove_trailing_string(name, 'ViewSet')
+ name = formatting.remove_trailing_string(name, "View")
+ name = formatting.remove_trailing_string(name, "ViewSet")
name = formatting.camelcase_to_spaces(name)
# Suffix may be set by some Views, such as a ViewSet.
- suffix = getattr(view, 'suffix', None)
+ suffix = getattr(view, "suffix", None)
if suffix:
- name += ' ' + suffix
+ name += " " + suffix
return name
@@ -54,9 +54,9 @@ def get_view_description(view, html=False):
This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
"""
# Description may be set by some Views, such as a ViewSet.
- description = getattr(view, 'description', None)
+ description = getattr(view, "description", None)
if description is None:
- description = view.__class__.__doc__ or ''
+ description = view.__class__.__doc__ or ""
description = formatting.dedent(smart_text(description))
if html:
@@ -65,7 +65,7 @@ def get_view_description(view, html=False):
def set_rollback():
- atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False)
+ atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)
@@ -87,15 +87,15 @@ def exception_handler(exc, context):
if isinstance(exc, exceptions.APIException):
headers = {}
- if getattr(exc, 'auth_header', None):
- headers['WWW-Authenticate'] = exc.auth_header
- if getattr(exc, 'wait', None):
- headers['Retry-After'] = '%d' % exc.wait
+ if getattr(exc, "auth_header", None):
+ headers["WWW-Authenticate"] = exc.auth_header
+ if getattr(exc, "wait", None):
+ headers["Retry-After"] = "%d" % exc.wait
if isinstance(exc.detail, (list, dict)):
data = exc.detail
else:
- data = {'detail': exc.detail}
+ data = {"detail": exc.detail}
set_rollback()
return Response(data, status=exc.status_code, headers=headers)
@@ -128,13 +128,15 @@ class APIView(View):
This allows us to discover information about the view when we do URL
reverse lookups. Used for breadcrumb generation.
"""
- if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
+ if isinstance(getattr(cls, "queryset", None), models.query.QuerySet):
+
def force_evaluation():
raise RuntimeError(
- 'Do not evaluate the `.queryset` attribute directly, '
- 'as the result will be cached and reused between requests. '
- 'Use `.all()` or call `.get_queryset()` instead.'
+ "Do not evaluate the `.queryset` attribute directly, "
+ "as the result will be cached and reused between requests. "
+ "Use `.all()` or call `.get_queryset()` instead."
)
+
cls.queryset._fetch_all = force_evaluation
view = super(APIView, cls).as_view(**initkwargs)
@@ -154,11 +156,9 @@ class APIView(View):
@property
def default_response_headers(self):
- headers = {
- 'Allow': ', '.join(self.allowed_methods),
- }
+ headers = {"Allow": ", ".join(self.allowed_methods)}
if len(self.renderer_classes) > 1:
- headers['Vary'] = 'Accept'
+ headers["Vary"] = "Accept"
return headers
def http_method_not_allowed(self, request, *args, **kwargs):
@@ -199,9 +199,9 @@ class APIView(View):
# Note: Additionally `request` and `encoding` will also be added
# to the context by the Request object.
return {
- 'view': self,
- 'args': getattr(self, 'args', ()),
- 'kwargs': getattr(self, 'kwargs', {})
+ "view": self,
+ "args": getattr(self, "args", ()),
+ "kwargs": getattr(self, "kwargs", {}),
}
def get_renderer_context(self):
@@ -212,10 +212,10 @@ class APIView(View):
# Note: Additionally 'response' will also be added to the context,
# by the Response object.
return {
- 'view': self,
- 'args': getattr(self, 'args', ()),
- 'kwargs': getattr(self, 'kwargs', {}),
- 'request': getattr(self, 'request', None)
+ "view": self,
+ "args": getattr(self, "args", ()),
+ "kwargs": getattr(self, "kwargs", {}),
+ "request": getattr(self, "request", None),
}
def get_exception_handler_context(self):
@@ -224,10 +224,10 @@ class APIView(View):
as the `context` argument.
"""
return {
- 'view': self,
- 'args': getattr(self, 'args', ()),
- 'kwargs': getattr(self, 'kwargs', {}),
- 'request': getattr(self, 'request', None)
+ "view": self,
+ "args": getattr(self, "args", ()),
+ "kwargs": getattr(self, "kwargs", {}),
+ "request": getattr(self, "request", None),
}
def get_view_name(self):
@@ -289,7 +289,7 @@ class APIView(View):
"""
Instantiate and return the content negotiation class to use.
"""
- if not getattr(self, '_negotiator', None):
+ if not getattr(self, "_negotiator", None):
self._negotiator = self.content_negotiation_class()
return self._negotiator
@@ -333,7 +333,7 @@ class APIView(View):
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
- request, message=getattr(permission, 'message', None)
+ request, message=getattr(permission, "message", None)
)
def check_object_permissions(self, request, obj):
@@ -344,7 +344,7 @@ class APIView(View):
for permission in self.get_permissions():
if not permission.has_object_permission(request, self, obj):
self.permission_denied(
- request, message=getattr(permission, 'message', None)
+ request, message=getattr(permission, "message", None)
)
def check_throttles(self, request):
@@ -379,7 +379,7 @@ class APIView(View):
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
- parser_context=parser_context
+ parser_context=parser_context,
)
def initial(self, request, *args, **kwargs):
@@ -407,13 +407,12 @@ class APIView(View):
"""
# Make the error obvious if a proper response is not returned
assert isinstance(response, HttpResponseBase), (
- 'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
- 'to be returned from the view, but received a `%s`'
- % type(response)
+ "Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` "
+ "to be returned from the view, but received a `%s`" % type(response)
)
if isinstance(response, Response):
- if not getattr(request, 'accepted_renderer', None):
+ if not getattr(request, "accepted_renderer", None):
neg = self.perform_content_negotiation(request, force=True)
request.accepted_renderer, request.accepted_media_type = neg
@@ -422,7 +421,7 @@ class APIView(View):
response.renderer_context = self.get_renderer_context()
# Add new vary headers to the response instead of overwriting.
- vary_headers = self.headers.pop('Vary', None)
+ vary_headers = self.headers.pop("Vary", None)
if vary_headers is not None:
patch_vary_headers(response, cc_delim_re.split(vary_headers))
@@ -436,8 +435,9 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
- if isinstance(exc, (exceptions.NotAuthenticated,
- exceptions.AuthenticationFailed)):
+ if isinstance(
+ exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
+ ):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
@@ -460,8 +460,8 @@ class APIView(View):
def raise_uncaught_exception(self, exc):
if settings.DEBUG:
request = self.request
- renderer_format = getattr(request.accepted_renderer, 'format')
- use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin')
+ renderer_format = getattr(request.accepted_renderer, "format")
+ use_plaintext_traceback = renderer_format not in ("html", "api", "admin")
request.force_plaintext_errors(use_plaintext_traceback)
raise exc
@@ -484,8 +484,9 @@ class APIView(View):
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
- handler = getattr(self, request.method.lower(),
- self.http_method_not_allowed)
+ handler = getattr(
+ self, request.method.lower(), self.http_method_not_allowed
+ )
else:
handler = self.http_method_not_allowed
diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py
index 7146828d2..fd57789a8 100644
--- a/rest_framework/viewsets.py
+++ b/rest_framework/viewsets.py
@@ -31,7 +31,7 @@ from rest_framework.reverse import reverse
def _is_extra_action(attr):
- return hasattr(attr, 'mapping')
+ return hasattr(attr, "mapping")
class ViewSetMixin(object):
@@ -73,24 +73,30 @@ class ViewSetMixin(object):
# actions must not be empty
if not actions:
- raise TypeError("The `actions` argument must be provided when "
- "calling `.as_view()` on a ViewSet. For example "
- "`.as_view({'get': 'list'})`")
+ raise TypeError(
+ "The `actions` argument must be provided when "
+ "calling `.as_view()` on a ViewSet. For example "
+ "`.as_view({'get': 'list'})`"
+ )
# sanitize keyword arguments
for key in initkwargs:
if key in cls.http_method_names:
- raise TypeError("You tried to pass in the %s method name as a "
- "keyword argument to %s(). Don't do that."
- % (key, cls.__name__))
+ raise TypeError(
+ "You tried to pass in the %s method name as a "
+ "keyword argument to %s(). Don't do that." % (key, cls.__name__)
+ )
if not hasattr(cls, key):
- raise TypeError("%s() received an invalid keyword %r" % (
- cls.__name__, key))
+ raise TypeError(
+ "%s() received an invalid keyword %r" % (cls.__name__, key)
+ )
# name and suffix are mutually exclusive
- if 'name' in initkwargs and 'suffix' in initkwargs:
- raise TypeError("%s() received both `name` and `suffix`, which are "
- "mutually exclusive arguments." % (cls.__name__))
+ if "name" in initkwargs and "suffix" in initkwargs:
+ raise TypeError(
+ "%s() received both `name` and `suffix`, which are "
+ "mutually exclusive arguments." % (cls.__name__)
+ )
def view(request, *args, **kwargs):
self = cls(**initkwargs)
@@ -105,7 +111,7 @@ class ViewSetMixin(object):
handler = getattr(self, action)
setattr(self, method, handler)
- if hasattr(self, 'get') and not hasattr(self, 'head'):
+ if hasattr(self, "get") and not hasattr(self, "head"):
self.head = self.get
self.request = request
@@ -136,11 +142,11 @@ class ViewSetMixin(object):
"""
request = super(ViewSetMixin, self).initialize_request(request, *args, **kwargs)
method = request.method.lower()
- if method == 'options':
+ if method == "options":
# This is a special case as we always provide handling for the
# options method in the base `View` class.
# Unlike the other explicitly defined actions, 'metadata' is implicit.
- self.action = 'metadata'
+ self.action = "metadata"
else:
self.action = self.action_map.get(method)
return request
@@ -149,8 +155,8 @@ class ViewSetMixin(object):
"""
Reverse the action for the given `url_name`.
"""
- url_name = '%s-%s' % (self.basename, url_name)
- kwargs.setdefault('request', self.request)
+ url_name = "%s-%s" % (self.basename, url_name)
+ kwargs.setdefault("request", self.request)
return reverse(url_name, *args, **kwargs)
@@ -175,13 +181,14 @@ class ViewSetMixin(object):
# filter for the relevant extra actions
actions = [
- action for action in self.get_extra_actions()
+ action
+ for action in self.get_extra_actions()
if action.detail == self.detail
]
for action in actions:
try:
- url_name = '%s-%s' % (self.basename, action.url_name)
+ url_name = "%s-%s" % (self.basename, action.url_name)
url = reverse(url_name, self.args, self.kwargs, request=self.request)
view = self.__class__(**action.kwargs)
action_urls[view.get_view_name()] = url
@@ -195,6 +202,7 @@ class ViewSet(ViewSetMixin, views.APIView):
"""
The base ViewSet class does not provide any actions by default.
"""
+
pass
@@ -204,26 +212,31 @@ class GenericViewSet(ViewSetMixin, generics.GenericAPIView):
but does include the base set of generic view behavior, such as
the `get_object` and `get_queryset` methods.
"""
+
pass
-class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
- mixins.ListModelMixin,
- GenericViewSet):
+class ReadOnlyModelViewSet(
+ mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet
+):
"""
A viewset that provides default `list()` and `retrieve()` actions.
"""
+
pass
-class ModelViewSet(mixins.CreateModelMixin,
- mixins.RetrieveModelMixin,
- mixins.UpdateModelMixin,
- mixins.DestroyModelMixin,
- mixins.ListModelMixin,
- GenericViewSet):
+class ModelViewSet(
+ mixins.CreateModelMixin,
+ mixins.RetrieveModelMixin,
+ mixins.UpdateModelMixin,
+ mixins.DestroyModelMixin,
+ mixins.ListModelMixin,
+ GenericViewSet,
+):
"""
A viewset that provides default `create()`, `retrieve()`, `update()`,
`partial_update()`, `destroy()` and `list()` actions.
"""
+
pass
diff --git a/runtests.py b/runtests.py
index 4dc475375..a5c30e11a 100755
--- a/runtests.py
+++ b/runtests.py
@@ -6,16 +6,23 @@ import sys
import pytest
-PYTEST_ARGS = {
- 'default': [],
- 'fast': ['-q'],
-}
-FLAKE8_ARGS = ['rest_framework', 'tests']
+PYTEST_ARGS = {"default": [], "fast": ["-q"]}
-ISORT_ARGS = ['--recursive', '--check-only', '--diff', '-o' 'uritemplate', '-p', 'tests', 'rest_framework', 'tests']
+FLAKE8_ARGS = ["rest_framework", "tests"]
-BLACK_ARGS = ['--check', '--verbose']
+ISORT_ARGS = [
+ "--recursive",
+ "--check-only",
+ "--diff",
+ "-o" "uritemplate",
+ "-p",
+ "tests",
+ "rest_framework",
+ "tests",
+]
+
+BLACK_ARGS = ["--check", "--verbose"]
def exit_on_failure(ret, message=None):
@@ -24,43 +31,48 @@ def exit_on_failure(ret, message=None):
def flake8_main(args):
- print('Running flake8 code linting')
- ret = subprocess.call(['flake8'] + args)
- print('flake8 failed' if ret else 'flake8 passed')
+ print("Running flake8 code linting")
+ ret = subprocess.call(["flake8"] + args)
+ print("flake8 failed" if ret else "flake8 passed")
return ret
def isort_main(args):
- print('Running isort code checking')
- ret = subprocess.call(['isort'] + args)
+ print("Running isort code checking")
+ ret = subprocess.call(["isort"] + args)
if ret:
- print('isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`')
+ print(
+ "isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`"
+ )
else:
- print('isort passed')
+ print("isort passed")
return ret
def black_main(args):
- print('Running black code checking')
- ret = subprocess.call(['black', '.'] + args)
+ print("Running black code checking")
+ ret = subprocess.call(["black", "."] + args)
if ret:
- print('black failed: Some code have incorrectly formatted. Fix by running `black .`')
+ print(
+ "black failed: Some code have incorrectly formatted. Fix by running `black .`"
+ )
else:
- print('black passed')
+ print("black passed")
return ret
+
def split_class_and_function(string):
- class_string, function_string = string.split('.', 1)
+ class_string, function_string = string.split(".", 1)
return "%s and %s" % (class_string, function_string)
def is_function(string):
# `True` if it looks like a test function is included in the string.
- return string.startswith('test_') or '.test_' in string
+ return string.startswith("test_") or ".test_" in string
def is_class(string):
@@ -70,7 +82,7 @@ def is_class(string):
if __name__ == "__main__":
try:
- sys.argv.remove('--nolint')
+ sys.argv.remove("--nolint")
except ValueError:
run_black = True
run_flake8 = True
@@ -81,18 +93,18 @@ if __name__ == "__main__":
run_isort = False
try:
- sys.argv.remove('--lintonly')
+ sys.argv.remove("--lintonly")
except ValueError:
run_tests = True
else:
run_tests = False
try:
- sys.argv.remove('--fast')
+ sys.argv.remove("--fast")
except ValueError:
- style = 'default'
+ style = "default"
else:
- style = 'fast'
+ style = "fast"
run_black = False
run_flake8 = False
run_isort = False
@@ -102,26 +114,23 @@ if __name__ == "__main__":
first_arg = pytest_args[0]
try:
- pytest_args.remove('--coverage')
+ pytest_args.remove("--coverage")
except ValueError:
pass
else:
- pytest_args = [
- '--cov', '.',
- '--cov-report', 'xml',
- ] + pytest_args
+ pytest_args = ["--cov", ".", "--cov-report", "xml"] + pytest_args
- if first_arg.startswith('-'):
+ if first_arg.startswith("-"):
# `runtests.py [flags]`
- pytest_args = ['tests'] + pytest_args
+ pytest_args = ["tests"] + pytest_args
elif is_class(first_arg) and is_function(first_arg):
# `runtests.py TestCase.test_function [flags]`
expression = split_class_and_function(first_arg)
- pytest_args = ['tests', '-k', expression] + pytest_args[1:]
+ pytest_args = ["tests", "-k", expression] + pytest_args[1:]
elif is_class(first_arg) or is_function(first_arg):
# `runtests.py TestCase [flags]`
# `runtests.py test_function [flags]`
- pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:]
+ pytest_args = ["tests", "-k", pytest_args[0]] + pytest_args[1:]
else:
pytest_args = PYTEST_ARGS[style]
diff --git a/setup.cfg b/setup.cfg
index c95134600..8056bfe93 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -9,16 +9,23 @@ addopts=--tb=short --strict -ra
testspath = tests
[flake8]
-ignore = E501
+max-line-length = 120
+ignore = E501, W503, E203
banned-modules = json = use from rest_framework.utils import json!
[isort]
skip=.tox
atomic=true
-multi_line_output=5
-known_standard_library=types
+multi_line_output=3
+lines_after_imports = 2
+black=types
+combine_as_imports = true
known_third_party=pytest,_pytest,django,pytz
-known_first_party=rest_framework
+known_first_party=rest_framework, tests
+include_trailing_comma=true
+line_length = 88
+balanced_wrapping = true
+sections = FUTURE, STDLIB, DJANGO, CMS, THIRDPARTY, FIRSTPARTY, LIB, LOCALFOLDER
[coverage:run]
# NOTE: source is ignored with pytest-cov (but uses the same).
diff --git a/setup.py b/setup.py
index cb850a3ae..769f71322 100755
--- a/setup.py
+++ b/setup.py
@@ -10,21 +10,21 @@ from setuptools import find_packages, setup
def read(f):
- return open(f, 'r', encoding='utf-8').read()
+ return open(f, "r", encoding="utf-8").read()
def get_version(package):
"""
Return package version as listed in `__version__` in `init.py`.
"""
- init_py = open(os.path.join(package, '__init__.py')).read()
+ init_py = open(os.path.join(package, "__init__.py")).read()
return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1)
-version = get_version('rest_framework')
+version = get_version("rest_framework")
-if sys.argv[-1] == 'publish':
+if sys.argv[-1] == "publish":
if os.system("pip freeze | grep twine"):
print("twine not installed.\nUse `pip install twine`.\nExiting.")
sys.exit()
@@ -33,48 +33,48 @@ if sys.argv[-1] == 'publish':
print("You probably want to also tag the version now:")
print(" git tag -a %s -m 'version %s'" % (version, version))
print(" git push --tags")
- shutil.rmtree('dist')
- shutil.rmtree('build')
- shutil.rmtree('djangorestframework.egg-info')
+ shutil.rmtree("dist")
+ shutil.rmtree("build")
+ shutil.rmtree("djangorestframework.egg-info")
sys.exit()
setup(
- name='djangorestframework',
+ name="djangorestframework",
version=version,
- url='https://www.django-rest-framework.org/',
- license='BSD',
- description='Web APIs for Django, made easy.',
- long_description=read('README.md'),
- long_description_content_type='text/markdown',
- author='Tom Christie',
- author_email='tom@tomchristie.com', # SEE NOTE BELOW (*)
- packages=find_packages(exclude=['tests*']),
+ url="https://www.django-rest-framework.org/",
+ license="BSD",
+ description="Web APIs for Django, made easy.",
+ long_description=read("README.md"),
+ long_description_content_type="text/markdown",
+ author="Tom Christie",
+ author_email="tom@tomchristie.com", # SEE NOTE BELOW (*)
+ packages=find_packages(exclude=["tests*"]),
include_package_data=True,
install_requires=[],
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
zip_safe=False,
classifiers=[
- 'Development Status :: 5 - Production/Stable',
- 'Environment :: Web Environment',
- 'Framework :: Django',
- 'Framework :: Django :: 1.11',
- 'Framework :: Django :: 2.0',
- 'Framework :: Django :: 2.1',
- 'Framework :: Django :: 2.2',
- 'Intended Audience :: Developers',
- 'License :: OSI Approved :: BSD License',
- 'Operating System :: OS Independent',
- 'Programming Language :: Python',
- 'Programming Language :: Python :: 2',
- 'Programming Language :: Python :: 2.7',
- 'Programming Language :: Python :: 3',
- 'Programming Language :: Python :: 3.4',
- 'Programming Language :: Python :: 3.5',
- 'Programming Language :: Python :: 3.6',
- 'Programming Language :: Python :: 3.7',
- 'Topic :: Internet :: WWW/HTTP',
- ]
+ "Development Status :: 5 - Production/Stable",
+ "Environment :: Web Environment",
+ "Framework :: Django",
+ "Framework :: Django :: 1.11",
+ "Framework :: Django :: 2.0",
+ "Framework :: Django :: 2.1",
+ "Framework :: Django :: 2.2",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: BSD License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
+ "Programming Language :: Python :: 2",
+ "Programming Language :: Python :: 2.7",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.4",
+ "Programming Language :: Python :: 3.5",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Topic :: Internet :: WWW/HTTP",
+ ],
)
# (*) Please direct queries to the discussion group, rather than to me directly
diff --git a/tests/authentication/migrations/0001_initial.py b/tests/authentication/migrations/0001_initial.py
index cfc887240..774b23316 100644
--- a/tests/authentication/migrations/0001_initial.py
+++ b/tests/authentication/migrations/0001_initial.py
@@ -9,16 +9,22 @@ class Migration(migrations.Migration):
initial = True
- dependencies = [
- migrations.swappable_dependency(settings.AUTH_USER_MODEL),
- ]
+ dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)]
operations = [
migrations.CreateModel(
- name='CustomToken',
+ name="CustomToken",
fields=[
- ('key', models.CharField(max_length=40, primary_key=True, serialize=False)),
- ('user', models.OneToOneField(on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL)),
+ (
+ "key",
+ models.CharField(max_length=40, primary_key=True, serialize=False),
+ ),
+ (
+ "user",
+ models.OneToOneField(
+ on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL
+ ),
+ ),
],
- ),
+ )
]
diff --git a/tests/authentication/test_authentication.py b/tests/authentication/test_authentication.py
index 793773542..dfbfdf0e1 100644
--- a/tests/authentication/test_authentication.py
+++ b/tests/authentication/test_authentication.py
@@ -13,11 +13,18 @@ from django.test import TestCase, override_settings
from django.utils import six
from rest_framework import (
- HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status
+ HTTP_HEADER_ENCODING,
+ exceptions,
+ permissions,
+ renderers,
+ status,
)
from rest_framework.authentication import (
- BaseAuthentication, BasicAuthentication, RemoteUserAuthentication,
- SessionAuthentication, TokenAuthentication
+ BaseAuthentication,
+ BasicAuthentication,
+ RemoteUserAuthentication,
+ SessionAuthentication,
+ TokenAuthentication,
)
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import obtain_auth_token
@@ -27,6 +34,7 @@ from rest_framework.views import APIView
from .models import CustomToken
+
factory = APIRequestFactory()
@@ -35,92 +43,77 @@ class CustomTokenAuthentication(TokenAuthentication):
class CustomKeywordTokenAuthentication(TokenAuthentication):
- keyword = 'Bearer'
+ keyword = "Bearer"
class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,)
def get(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+ return HttpResponse({"a": 1, "b": 2, "c": 3})
def post(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+ return HttpResponse({"a": 1, "b": 2, "c": 3})
def put(self, request):
- return HttpResponse({'a': 1, 'b': 2, 'c': 3})
+ return HttpResponse({"a": 1, "b": 2, "c": 3})
urlpatterns = [
url(
- r'^session/$',
- MockView.as_view(authentication_classes=[SessionAuthentication])
+ r"^session/$", MockView.as_view(authentication_classes=[SessionAuthentication])
+ ),
+ url(r"^basic/$", MockView.as_view(authentication_classes=[BasicAuthentication])),
+ url(
+ r"^remote-user/$",
+ MockView.as_view(authentication_classes=[RemoteUserAuthentication]),
+ ),
+ url(r"^token/$", MockView.as_view(authentication_classes=[TokenAuthentication])),
+ url(
+ r"^customtoken/$",
+ MockView.as_view(authentication_classes=[CustomTokenAuthentication]),
),
url(
- r'^basic/$',
- MockView.as_view(authentication_classes=[BasicAuthentication])
+ r"^customkeywordtoken/$",
+ MockView.as_view(authentication_classes=[CustomKeywordTokenAuthentication]),
),
- url(
- r'^remote-user/$',
- MockView.as_view(authentication_classes=[RemoteUserAuthentication])
- ),
- url(
- r'^token/$',
- MockView.as_view(authentication_classes=[TokenAuthentication])
- ),
- url(
- r'^customtoken/$',
- MockView.as_view(authentication_classes=[CustomTokenAuthentication])
- ),
- url(
- r'^customkeywordtoken/$',
- MockView.as_view(
- authentication_classes=[CustomKeywordTokenAuthentication]
- )
- ),
- url(r'^auth-token/$', obtain_auth_token),
- url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
+ url(r"^auth-token/$", obtain_auth_token),
+ url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")),
]
@override_settings(ROOT_URLCONF=__name__)
class BasicAuthTests(TestCase):
"""Basic authentication"""
+
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username, self.email, self.password
- )
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
+ self.user = User.objects.create_user(self.username, self.email, self.password)
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
- credentials = ('%s:%s' % (self.username, self.password))
+ credentials = "%s:%s" % (self.username, self.password)
base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING)
).decode(HTTP_HEADER_ENCODING)
- auth = 'Basic %s' % base64_credentials
+ auth = "Basic %s" % base64_credentials
response = self.csrf_client.post(
- '/basic/',
- {'example': 'example'},
- HTTP_AUTHORIZATION=auth
+ "/basic/", {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
- credentials = ('%s:%s' % (self.username, self.password))
+ credentials = "%s:%s" % (self.username, self.password)
base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING)
).decode(HTTP_HEADER_ENCODING)
- auth = 'Basic %s' % base64_credentials
+ auth = "Basic %s" % base64_credentials
response = self.csrf_client.post(
- '/basic/',
- {'example': 'example'},
- format='json',
- HTTP_AUTHORIZATION=auth
+ "/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
@@ -128,39 +121,34 @@ class BasicAuthTests(TestCase):
"""Ensure POSTing JSON over basic auth with incorrectly padded Base64 string is handled correctly"""
# regression test for issue in 'rest_framework.authentication.BasicAuthentication.authenticate'
# https://github.com/encode/django-rest-framework/issues/4089
- auth = 'Basic =a='
+ auth = "Basic =a="
response = self.csrf_client.post(
- '/basic/',
- {'example': 'example'},
- format='json',
- HTTP_AUTHORIZATION=auth
+ "/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
- response = self.csrf_client.post('/basic/', {'example': 'example'})
+ response = self.csrf_client.post("/basic/", {"example": "example"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post(
- '/basic/',
- {'example': 'example'},
- format='json'
+ "/basic/", {"example": "example"}, format="json"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
- assert response['WWW-Authenticate'] == 'Basic realm="api"'
+ assert response["WWW-Authenticate"] == 'Basic realm="api"'
def test_fail_post_if_credentials_are_missing(self):
response = self.csrf_client.post(
- '/basic/', {'example': 'example'}, HTTP_AUTHORIZATION='Basic ')
+ "/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic "
+ )
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_credentials_contain_spaces(self):
response = self.csrf_client.post(
- '/basic/', {'example': 'example'},
- HTTP_AUTHORIZATION='Basic foo bar'
+ "/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic foo bar"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -168,15 +156,14 @@ class BasicAuthTests(TestCase):
@override_settings(ROOT_URLCONF=__name__)
class SessionAuthTests(TestCase):
"""User session authentication"""
+
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.non_csrf_client = APIClient(enforce_csrf_checks=False)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username, self.email, self.password
- )
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
+ self.user = User.objects.create_user(self.username, self.email, self.password)
def tearDown(self):
self.csrf_client.logout()
@@ -187,8 +174,8 @@ class SessionAuthTests(TestCase):
cf. [#1810](https://github.com/encode/django-rest-framework/pull/1810)
"""
- response = self.csrf_client.get('/auth/login/')
- content = response.content.decode('utf8')
+ response = self.csrf_client.get("/auth/login/")
+ content = response.content.decode("utf8")
assert '' in content
def test_post_form_session_auth_failing_csrf(self):
@@ -196,7 +183,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
- response = self.csrf_client.post('/session/', {'example': 'example'})
+ response = self.csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_post_form_session_auth_passing_csrf(self):
@@ -213,10 +200,9 @@ class SessionAuthTests(TestCase):
self.csrf_client.cookies[settings.CSRF_COOKIE_NAME] = token
# Post the token matching the cookie value
- response = self.csrf_client.post('/session/', {
- 'example': 'example',
- 'csrfmiddlewaretoken': token,
- })
+ response = self.csrf_client.post(
+ "/session/", {"example": "example", "csrfmiddlewaretoken": token}
+ )
assert response.status_code == status.HTTP_200_OK
def test_post_form_session_auth_passing(self):
@@ -224,12 +210,8 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in
user and CSRF token passes.
"""
- self.non_csrf_client.login(
- username=self.username, password=self.password
- )
- response = self.non_csrf_client.post(
- '/session/', {'example': 'example'}
- )
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_200_OK
def test_put_form_session_auth_passing(self):
@@ -237,38 +219,33 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with
logged in user and CSRF token passes.
"""
- self.non_csrf_client.login(
- username=self.username, password=self.password
- )
- response = self.non_csrf_client.put(
- '/session/', {'example': 'example'}
- )
+ self.non_csrf_client.login(username=self.username, password=self.password)
+ response = self.non_csrf_client.put("/session/", {"example": "example"})
assert response.status_code == status.HTTP_200_OK
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
- response = self.csrf_client.post('/session/', {'example': 'example'})
+ response = self.csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_403_FORBIDDEN
class BaseTokenAuthTests(object):
"""Token authentication"""
+
model = None
path = None
- header_prefix = 'Token '
+ header_prefix = "Token "
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username, self.email, self.password
- )
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
+ self.user = User.objects.create_user(self.username, self.email, self.password)
- self.key = 'abcd1234'
+ self.key = "abcd1234"
self.token = self.model.objects.create(key=self.key, user=self.user)
def test_post_form_passing_token_auth(self):
@@ -278,39 +255,41 @@ class BaseTokenAuthTests(object):
"""
auth = self.header_prefix + self.key
response = self.csrf_client.post(
- self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth
+ self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
def test_fail_authentication_if_user_is_not_active(self):
- user = User.objects.create_user('foo', 'bar', 'baz')
+ user = User.objects.create_user("foo", "bar", "baz")
user.is_active = False
user.save()
- self.model.objects.create(key='foobar_token', user=user)
+ self.model.objects.create(key="foobar_token", user=user)
response = self.csrf_client.post(
- self.path, {'example': 'example'},
- HTTP_AUTHORIZATION=self.header_prefix + 'foobar_token'
+ self.path,
+ {"example": "example"},
+ HTTP_AUTHORIZATION=self.header_prefix + "foobar_token",
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key
- auth = self.header_prefix + 'wxyz6789'
+ auth = self.header_prefix + "wxyz6789"
response = self.csrf_client.post(
- self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth
+ self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_token_is_missing(self):
response = self.csrf_client.post(
- self.path, {'example': 'example'},
- HTTP_AUTHORIZATION=self.header_prefix)
+ self.path, {"example": "example"}, HTTP_AUTHORIZATION=self.header_prefix
+ )
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_token_contains_spaces(self):
response = self.csrf_client.post(
- self.path, {'example': 'example'},
- HTTP_AUTHORIZATION=self.header_prefix + 'foo bar'
+ self.path,
+ {"example": "example"},
+ HTTP_AUTHORIZATION=self.header_prefix + "foo bar",
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -318,7 +297,7 @@ class BaseTokenAuthTests(object):
# add an 'invalid' unicode character
auth = self.header_prefix + self.key + "¸"
response = self.csrf_client.post(
- self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth
+ self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -329,8 +308,7 @@ class BaseTokenAuthTests(object):
"""
auth = self.header_prefix + self.key
response = self.csrf_client.post(
- self.path, {'example': 'example'},
- format='json', HTTP_AUTHORIZATION=auth
+ self.path, {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
@@ -343,8 +321,10 @@ class BaseTokenAuthTests(object):
def func_to_test():
return self.csrf_client.post(
- self.path, {'example': 'example'},
- format='json', HTTP_AUTHORIZATION=auth
+ self.path,
+ {"example": "example"},
+ format="json",
+ HTTP_AUTHORIZATION=auth,
)
self.assertNumQueries(1, func_to_test)
@@ -353,7 +333,7 @@ class BaseTokenAuthTests(object):
"""
Ensure POSTing form over token auth without correct credentials fails
"""
- response = self.csrf_client.post(self.path, {'example': 'example'})
+ response = self.csrf_client.post(self.path, {"example": "example"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_json_failing_token_auth(self):
@@ -361,7 +341,7 @@ class BaseTokenAuthTests(object):
Ensure POSTing json over token auth without correct credentials fails
"""
response = self.csrf_client.post(
- self.path, {'example': 'example'}, format='json'
+ self.path, {"example": "example"}, format="json"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@@ -369,7 +349,7 @@ class BaseTokenAuthTests(object):
@override_settings(ROOT_URLCONF=__name__)
class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
- path = '/token/'
+ path = "/token/"
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
@@ -387,12 +367,12 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase):
"""Ensure token login view using JSON POST works."""
client = APIClient(enforce_csrf_checks=True)
response = client.post(
- '/auth-token/',
- {'username': self.username, 'password': self.password},
- format='json'
+ "/auth-token/",
+ {"username": self.username, "password": self.password},
+ format="json",
)
assert response.status_code == status.HTTP_200_OK
- assert response.data['token'] == self.key
+ assert response.data["token"] == self.key
def test_token_login_json_bad_creds(self):
"""
@@ -401,41 +381,41 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase):
"""
client = APIClient(enforce_csrf_checks=True)
response = client.post(
- '/auth-token/',
- {'username': self.username, 'password': "badpass"},
- format='json'
+ "/auth-token/",
+ {"username": self.username, "password": "badpass"},
+ format="json",
)
assert response.status_code == 400
def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields."""
client = APIClient(enforce_csrf_checks=True)
- response = client.post('/auth-token/',
- {'username': self.username}, format='json')
+ response = client.post(
+ "/auth-token/", {"username": self.username}, format="json"
+ )
assert response.status_code == 400
def test_token_login_form(self):
"""Ensure token login view using form POST works."""
client = APIClient(enforce_csrf_checks=True)
response = client.post(
- '/auth-token/',
- {'username': self.username, 'password': self.password}
+ "/auth-token/", {"username": self.username, "password": self.password}
)
assert response.status_code == status.HTTP_200_OK
- assert response.data['token'] == self.key
+ assert response.data["token"] == self.key
@override_settings(ROOT_URLCONF=__name__)
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken
- path = '/customtoken/'
+ path = "/customtoken/"
@override_settings(ROOT_URLCONF=__name__)
class CustomKeywordTokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
- path = '/customkeywordtoken/'
- header_prefix = 'Bearer '
+ path = "/customkeywordtoken/"
+ header_prefix = "Bearer "
class IncorrectCredentialsTests(TestCase):
@@ -445,42 +425,42 @@ class IncorrectCredentialsTests(TestCase):
authentication should run and error, even if no permissions
are set on the view.
"""
+
class IncorrectCredentialsAuth(BaseAuthentication):
def authenticate(self, request):
- raise exceptions.AuthenticationFailed('Bad credentials')
+ raise exceptions.AuthenticationFailed("Bad credentials")
- request = factory.get('/')
+ request = factory.get("/")
view = MockView.as_view(
- authentication_classes=(IncorrectCredentialsAuth,),
- permission_classes=()
+ authentication_classes=(IncorrectCredentialsAuth,), permission_classes=()
)
response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN
- assert response.data == {'detail': 'Bad credentials'}
+ assert response.data == {"detail": "Bad credentials"}
class FailingAuthAccessedInRenderer(TestCase):
def setUp(self):
class AuthAccessingRenderer(renderers.BaseRenderer):
- media_type = 'text/plain'
- format = 'txt'
+ media_type = "text/plain"
+ format = "txt"
def render(self, data, media_type=None, renderer_context=None):
- request = renderer_context['request']
+ request = renderer_context["request"]
if request.user.is_authenticated:
- return b'authenticated'
- return b'not authenticated'
+ return b"authenticated"
+ return b"not authenticated"
class FailingAuth(BaseAuthentication):
def authenticate(self, request):
- raise exceptions.AuthenticationFailed('authentication failed')
+ raise exceptions.AuthenticationFailed("authentication failed")
class ExampleView(APIView):
authentication_classes = (FailingAuth,)
renderer_classes = (AuthAccessingRenderer,)
def get(self, request):
- return Response({'foo': 'bar'})
+ return Response({"foo": "bar"})
self.view = ExampleView.as_view()
@@ -490,10 +470,10 @@ class FailingAuthAccessedInRenderer(TestCase):
`request.user` without raising an exception. Particularly relevant
to HTML responses that might reasonably access `request.user`.
"""
- request = factory.get('/')
+ request = factory.get("/")
response = self.view(request)
content = response.render().content
- assert content == b'not authenticated'
+ assert content == b"not authenticated"
class NoAuthenticationClassesTests(TestCase):
@@ -505,23 +485,21 @@ class NoAuthenticationClassesTests(TestCase):
"""
class DummyPermission(permissions.BasePermission):
- message = 'Dummy permission message'
+ message = "Dummy permission message"
def has_permission(self, request, view):
return False
- request = factory.get('/')
+ request = factory.get("/")
view = MockView.as_view(
- authentication_classes=(),
- permission_classes=(DummyPermission,),
+ authentication_classes=(), permission_classes=(DummyPermission,)
)
response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN
- assert response.data == {'detail': 'Dummy permission message'}
+ assert response.data == {"detail": "Dummy permission message"}
class BasicAuthenticationUnitTests(TestCase):
-
def test_base_authentication_abstract_method(self):
with pytest.raises(NotImplementedError):
BaseAuthentication().authenticate({})
@@ -529,34 +507,34 @@ class BasicAuthenticationUnitTests(TestCase):
def test_basic_authentication_raises_error_if_user_not_found(self):
auth = BasicAuthentication()
with pytest.raises(exceptions.AuthenticationFailed):
- auth.authenticate_credentials('invalid id', 'invalid password')
+ auth.authenticate_credentials("invalid id", "invalid password")
def test_basic_authentication_raises_error_if_user_not_active(self):
from rest_framework import authentication
class MockUser(object):
is_active = False
+
old_authenticate = authentication.authenticate
authentication.authenticate = lambda **kwargs: MockUser()
auth = authentication.BasicAuthentication()
with pytest.raises(exceptions.AuthenticationFailed) as error:
- auth.authenticate_credentials('foo', 'bar')
- assert 'User inactive or deleted.' in str(error)
+ auth.authenticate_credentials("foo", "bar")
+ assert "User inactive or deleted." in str(error)
authentication.authenticate = old_authenticate
-@override_settings(ROOT_URLCONF=__name__,
- AUTHENTICATION_BACKENDS=('django.contrib.auth.backends.RemoteUserBackend',))
+@override_settings(
+ ROOT_URLCONF=__name__,
+ AUTHENTICATION_BACKENDS=("django.contrib.auth.backends.RemoteUserBackend",),
+)
class RemoteUserAuthenticationUnitTests(TestCase):
def setUp(self):
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username, self.email, self.password
- )
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
+ self.user = User.objects.create_user(self.username, self.email, self.password)
def test_remote_user_works(self):
- response = self.client.post('/remote-user/',
- REMOTE_USER=self.username)
+ response = self.client.post("/remote-user/", REMOTE_USER=self.username)
self.assertEqual(response.status_code, status.HTTP_200_OK)
diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py
index 0e9379717..8dd78aef1 100644
--- a/tests/browsable_api/auth_urls.py
+++ b/tests/browsable_api/auth_urls.py
@@ -4,7 +4,8 @@ from django.conf.urls import include, url
from .views import MockView
+
urlpatterns = [
- url(r'^$', MockView.as_view()),
- url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
+ url(r"^$", MockView.as_view()),
+ url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")),
]
diff --git a/tests/browsable_api/no_auth_urls.py b/tests/browsable_api/no_auth_urls.py
index 5fc95c727..505e9a762 100644
--- a/tests/browsable_api/no_auth_urls.py
+++ b/tests/browsable_api/no_auth_urls.py
@@ -4,6 +4,5 @@ from django.conf.urls import url
from .views import MockView
-urlpatterns = [
- url(r'^$', MockView.as_view()),
-]
+
+urlpatterns = [url(r"^$", MockView.as_view())]
diff --git a/tests/browsable_api/test_browsable_api.py b/tests/browsable_api/test_browsable_api.py
index 684d7ae14..72e2fd1de 100644
--- a/tests/browsable_api/test_browsable_api.py
+++ b/tests/browsable_api/test_browsable_api.py
@@ -6,71 +6,65 @@ from django.test import TestCase, override_settings
from rest_framework.test import APIClient
-@override_settings(ROOT_URLCONF='tests.browsable_api.auth_urls')
+@override_settings(ROOT_URLCONF="tests.browsable_api.auth_urls")
class DropdownWithAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views enabled."""
+
def setUp(self):
self.client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username,
- self.email,
- self.password
- )
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
+ self.user = User.objects.create_user(self.username, self.email, self.password)
def tearDown(self):
self.client.logout()
def test_name_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
- response = self.client.get('/')
- content = response.content.decode('utf8')
- assert 'john' in content
+ response = self.client.get("/")
+ content = response.content.decode("utf8")
+ assert "john" in content
def test_logout_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
- response = self.client.get('/')
- content = response.content.decode('utf8')
- assert '>Log out<' in content
+ response = self.client.get("/")
+ content = response.content.decode("utf8")
+ assert ">Log out<" in content
def test_login_shown_when_logged_out(self):
- response = self.client.get('/')
- content = response.content.decode('utf8')
- assert '>Log in<' in content
+ response = self.client.get("/")
+ content = response.content.decode("utf8")
+ assert ">Log in<" in content
-@override_settings(ROOT_URLCONF='tests.browsable_api.no_auth_urls')
+@override_settings(ROOT_URLCONF="tests.browsable_api.no_auth_urls")
class NoDropdownWithoutAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views NOT enabled."""
+
def setUp(self):
self.client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username,
- self.email,
- self.password
- )
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
+ self.user = User.objects.create_user(self.username, self.email, self.password)
def tearDown(self):
self.client.logout()
def test_name_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
- response = self.client.get('/')
- content = response.content.decode('utf8')
- assert 'john' in content
+ response = self.client.get("/")
+ content = response.content.decode("utf8")
+ assert "john" in content
def test_dropdown_not_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
- response = self.client.get('/')
- content = response.content.decode('utf8')
+ response = self.client.get("/")
+ content = response.content.decode("utf8")
assert '' not in content
def test_dropdown_not_shown_when_logged_out(self):
- response = self.client.get('/')
- content = response.content.decode('utf8')
+ response = self.client.get("/")
+ content = response.content.decode("utf8")
assert '' not in content
diff --git a/tests/browsable_api/test_browsable_nested_api.py b/tests/browsable_api/test_browsable_nested_api.py
index 8f38b3c4e..23925f42f 100644
--- a/tests/browsable_api/test_browsable_nested_api.py
+++ b/tests/browsable_api/test_browsable_nested_api.py
@@ -19,24 +19,22 @@ class NestedSerializerTestSerializer(serializers.Serializer):
class NestedSerializersView(ListCreateAPIView):
- renderer_classes = (BrowsableAPIRenderer, )
+ renderer_classes = (BrowsableAPIRenderer,)
serializer_class = NestedSerializerTestSerializer
- queryset = [{'nested': {'one': 1, 'two': 2}}]
+ queryset = [{"nested": {"one": 1, "two": 2}}]
-urlpatterns = [
- url(r'^api/$', NestedSerializersView.as_view(), name='api'),
-]
+urlpatterns = [url(r"^api/$", NestedSerializersView.as_view(), name="api")]
class DropdownWithAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views enabled."""
- @override_settings(ROOT_URLCONF='tests.browsable_api.test_browsable_nested_api')
+ @override_settings(ROOT_URLCONF="tests.browsable_api.test_browsable_nested_api")
def test_login(self):
- response = self.client.get('/api/')
+ response = self.client.get("/api/")
assert 200 == response.status_code
- content = response.content.decode('utf-8')
+ content = response.content.decode("utf-8")
assert 'form action="/api/"' in content
assert 'input name="nested.one"' in content
assert 'input name="nested.two"' in content
diff --git a/tests/browsable_api/test_form_rendering.py b/tests/browsable_api/test_form_rendering.py
index d8378a2ca..29469516d 100644
--- a/tests/browsable_api/test_form_rendering.py
+++ b/tests/browsable_api/test_form_rendering.py
@@ -5,13 +5,14 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from tests.models import BasicModel
+
factory = APIRequestFactory()
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
- fields = '__all__'
+ fields = "__all__"
class StandardPostView(generics.CreateAPIView):
@@ -39,19 +40,19 @@ class TestPostingListData(TestCase):
def test_json_response(self):
# sanity check for non-browsable API responses
view = StandardPostView.as_view()
- request = factory.post('/', [{}], format='json')
+ request = factory.post("/", [{}], format="json")
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertTrue('non_field_errors' in response.data)
+ self.assertTrue("non_field_errors" in response.data)
def test_browsable_api(self):
view = StandardPostView.as_view()
- request = factory.post('/?format=api', [{}], format='json')
+ request = factory.post("/?format=api", [{}], format="json")
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
- self.assertTrue('non_field_errors' in response.data)
+ self.assertTrue("non_field_errors" in response.data)
class TestManyPostView(TestCase):
@@ -59,14 +60,11 @@ class TestManyPostView(TestCase):
"""
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz']
+ items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
+ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
self.view = ManyPostView.as_view()
def test_post_many_post_view(self):
@@ -77,7 +75,7 @@ class TestManyPostView(TestCase):
Regression test for https://github.com/encode/django-rest-framework/pull/3164
"""
data = {}
- request = factory.post('/', data, format='json')
+ request = factory.post("/", data, format="json")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
diff --git a/tests/browsable_api/views.py b/tests/browsable_api/views.py
index 03758f10b..926f55f83 100644
--- a/tests/browsable_api/views.py
+++ b/tests/browsable_api/views.py
@@ -10,4 +10,4 @@ class MockView(APIView):
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get(self, request):
- return Response({'a': 1, 'b': 2, 'c': 3})
+ return Response({"a": 1, "b": 2, "c": 3})
diff --git a/tests/conftest.py b/tests/conftest.py
index ac29e4a42..89ba66f50 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,13 +6,21 @@ from django.core import management
def pytest_addoption(parser):
- parser.addoption('--no-pkgroot', action='store_true', default=False,
- help='Remove package root directory from sys.path, ensuring that '
- 'rest_framework is imported from the installed site-packages. '
- 'Used for testing the distribution.')
- parser.addoption('--staticfiles', action='store_true', default=False,
- help='Run tests with static files collection, using manifest '
- 'staticfiles storage. Used for testing the distribution.')
+ parser.addoption(
+ "--no-pkgroot",
+ action="store_true",
+ default=False,
+ help="Remove package root directory from sys.path, ensuring that "
+ "rest_framework is imported from the installed site-packages. "
+ "Used for testing the distribution.",
+ )
+ parser.addoption(
+ "--staticfiles",
+ action="store_true",
+ default=False,
+ help="Run tests with static files collection, using manifest "
+ "staticfiles storage. Used for testing the distribution.",
+ )
def pytest_configure(config):
@@ -21,49 +29,42 @@ def pytest_configure(config):
settings.configure(
DEBUG_PROPAGATE_EXCEPTIONS=True,
DATABASES={
- 'default': {
- 'ENGINE': 'django.db.backends.sqlite3',
- 'NAME': ':memory:'
- }
+ "default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}
},
SITE_ID=1,
- SECRET_KEY='not very secret in tests',
+ SECRET_KEY="not very secret in tests",
USE_I18N=True,
USE_L10N=True,
- STATIC_URL='/static/',
- ROOT_URLCONF='tests.urls',
+ STATIC_URL="/static/",
+ ROOT_URLCONF="tests.urls",
TEMPLATES=[
{
- 'BACKEND': 'django.template.backends.django.DjangoTemplates',
- 'APP_DIRS': True,
- 'OPTIONS': {
- "debug": True, # We want template errors to raise
- }
- },
+ "BACKEND": "django.template.backends.django.DjangoTemplates",
+ "APP_DIRS": True,
+ "OPTIONS": {"debug": True}, # We want template errors to raise
+ }
],
MIDDLEWARE=(
- 'django.middleware.common.CommonMiddleware',
- 'django.contrib.sessions.middleware.SessionMiddleware',
- 'django.contrib.auth.middleware.AuthenticationMiddleware',
- 'django.contrib.messages.middleware.MessageMiddleware',
+ "django.middleware.common.CommonMiddleware",
+ "django.contrib.sessions.middleware.SessionMiddleware",
+ "django.contrib.auth.middleware.AuthenticationMiddleware",
+ "django.contrib.messages.middleware.MessageMiddleware",
),
INSTALLED_APPS=(
- 'django.contrib.admin',
- 'django.contrib.auth',
- 'django.contrib.contenttypes',
- 'django.contrib.sessions',
- 'django.contrib.sites',
- 'django.contrib.staticfiles',
- 'rest_framework',
- 'rest_framework.authtoken',
- 'tests.authentication',
- 'tests.generic_relations',
- 'tests.importable',
- 'tests',
- ),
- PASSWORD_HASHERS=(
- 'django.contrib.auth.hashers.MD5PasswordHasher',
+ "django.contrib.admin",
+ "django.contrib.auth",
+ "django.contrib.contenttypes",
+ "django.contrib.sessions",
+ "django.contrib.sites",
+ "django.contrib.staticfiles",
+ "rest_framework",
+ "rest_framework.authtoken",
+ "tests.authentication",
+ "tests.generic_relations",
+ "tests.importable",
+ "tests",
),
+ PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",),
)
# guardian is optional
@@ -74,28 +75,32 @@ def pytest_configure(config):
else:
settings.ANONYMOUS_USER_ID = -1
settings.AUTHENTICATION_BACKENDS = (
- 'django.contrib.auth.backends.ModelBackend',
- 'guardian.backends.ObjectPermissionBackend',
- )
- settings.INSTALLED_APPS += (
- 'guardian',
+ "django.contrib.auth.backends.ModelBackend",
+ "guardian.backends.ObjectPermissionBackend",
)
+ settings.INSTALLED_APPS += ("guardian",)
- if config.getoption('--no-pkgroot'):
+ if config.getoption("--no-pkgroot"):
sys.path.pop(0)
# import rest_framework before pytest re-adds the package root directory.
import rest_framework
- package_dir = os.path.join(os.getcwd(), 'rest_framework')
+
+ package_dir = os.path.join(os.getcwd(), "rest_framework")
assert not rest_framework.__file__.startswith(package_dir)
# Manifest storage will raise an exception if static files are not present (ie, a packaging failure).
- if config.getoption('--staticfiles'):
+ if config.getoption("--staticfiles"):
import rest_framework
- settings.STATIC_ROOT = os.path.join(os.path.dirname(rest_framework.__file__), 'static-root')
- settings.STATICFILES_STORAGE = 'django.contrib.staticfiles.storage.ManifestStaticFilesStorage'
+
+ settings.STATIC_ROOT = os.path.join(
+ os.path.dirname(rest_framework.__file__), "static-root"
+ )
+ settings.STATICFILES_STORAGE = (
+ "django.contrib.staticfiles.storage.ManifestStaticFilesStorage"
+ )
django.setup()
- if config.getoption('--staticfiles'):
- management.call_command('collectstatic', verbosity=0, interactive=False)
+ if config.getoption("--staticfiles"):
+ management.call_command("collectstatic", verbosity=0, interactive=False)
diff --git a/tests/generic_relations/migrations/0001_initial.py b/tests/generic_relations/migrations/0001_initial.py
index ea04d8d67..e1c4259b9 100644
--- a/tests/generic_relations/migrations/0001_initial.py
+++ b/tests/generic_relations/migrations/0001_initial.py
@@ -5,32 +5,59 @@ class Migration(migrations.Migration):
initial = True
- dependencies = [
- ('contenttypes', '0002_remove_content_type_name'),
- ]
+ dependencies = [("contenttypes", "0002_remove_content_type_name")]
operations = [
migrations.CreateModel(
- name='Bookmark',
+ name="Bookmark",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('url', models.URLField()),
+ (
+ "id",
+ models.AutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("url", models.URLField()),
],
),
migrations.CreateModel(
- name='Note',
+ name="Note",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('text', models.TextField()),
+ (
+ "id",
+ models.AutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("text", models.TextField()),
],
),
migrations.CreateModel(
- name='Tag',
+ name="Tag",
fields=[
- ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
- ('tag', models.SlugField()),
- ('object_id', models.PositiveIntegerField()),
- ('content_type', models.ForeignKey(on_delete=models.CASCADE, to='contenttypes.ContentType')),
+ (
+ "id",
+ models.AutoField(
+ auto_created=True,
+ primary_key=True,
+ serialize=False,
+ verbose_name="ID",
+ ),
+ ),
+ ("tag", models.SlugField()),
+ ("object_id", models.PositiveIntegerField()),
+ (
+ "content_type",
+ models.ForeignKey(
+ on_delete=models.CASCADE, to="contenttypes.ContentType"
+ ),
+ ),
],
),
]
diff --git a/tests/generic_relations/models.py b/tests/generic_relations/models.py
index 55bc243cb..1840db341 100644
--- a/tests/generic_relations/models.py
+++ b/tests/generic_relations/models.py
@@ -1,8 +1,6 @@
from __future__ import unicode_literals
-from django.contrib.contenttypes.fields import (
- GenericForeignKey, GenericRelation
-)
+from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
@@ -13,10 +11,11 @@ class Tag(models.Model):
"""
Tags have a descriptive slug, and are attached to an arbitrary object.
"""
+
tag = models.SlugField()
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
object_id = models.PositiveIntegerField()
- tagged_item = GenericForeignKey('content_type', 'object_id')
+ tagged_item = GenericForeignKey("content_type", "object_id")
def __str__(self):
return self.tag
@@ -27,11 +26,12 @@ class Bookmark(models.Model):
"""
A URL bookmark that may have multiple tags attached.
"""
+
url = models.URLField()
tags = GenericRelation(Tag)
def __str__(self):
- return 'Bookmark: %s' % self.url
+ return "Bookmark: %s" % self.url
@python_2_unicode_compatible
@@ -39,8 +39,9 @@ class Note(models.Model):
"""
A textual note that may have multiple tags attached.
"""
+
text = models.TextField()
tags = GenericRelation(Tag)
def __str__(self):
- return 'Note: %s' % self.text
+ return "Note: %s" % self.text
diff --git a/tests/generic_relations/test_generic_relations.py b/tests/generic_relations/test_generic_relations.py
index c8de332e1..251a6236b 100644
--- a/tests/generic_relations/test_generic_relations.py
+++ b/tests/generic_relations/test_generic_relations.py
@@ -9,11 +9,11 @@ from .models import Bookmark, Note, Tag
class TestGenericRelations(TestCase):
def setUp(self):
- self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
- Tag.objects.create(tagged_item=self.bookmark, tag='django')
- Tag.objects.create(tagged_item=self.bookmark, tag='python')
- self.note = Note.objects.create(text='Remember the milk')
- Tag.objects.create(tagged_item=self.note, tag='reminder')
+ self.bookmark = Bookmark.objects.create(url="https://www.djangoproject.com/")
+ Tag.objects.create(tagged_item=self.bookmark, tag="django")
+ Tag.objects.create(tagged_item=self.bookmark, tag="python")
+ self.note = Note.objects.create(text="Remember the milk")
+ Tag.objects.create(tagged_item=self.note, tag="reminder")
def test_generic_relation(self):
"""
@@ -26,12 +26,12 @@ class TestGenericRelations(TestCase):
class Meta:
model = Bookmark
- fields = ('tags', 'url')
+ fields = ("tags", "url")
serializer = BookmarkSerializer(self.bookmark)
expected = {
- 'tags': ['django', 'python'],
- 'url': 'https://www.djangoproject.com/'
+ "tags": ["django", "python"],
+ "url": "https://www.djangoproject.com/",
}
assert serializer.data == expected
@@ -46,21 +46,18 @@ class TestGenericRelations(TestCase):
class Meta:
model = Tag
- fields = ('tag', 'tagged_item')
+ fields = ("tag", "tagged_item")
serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [
{
- 'tag': 'django',
- 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ "tag": "django",
+ "tagged_item": "Bookmark: https://www.djangoproject.com/",
},
{
- 'tag': 'python',
- 'tagged_item': 'Bookmark: https://www.djangoproject.com/'
+ "tag": "python",
+ "tagged_item": "Bookmark: https://www.djangoproject.com/",
},
- {
- 'tag': 'reminder',
- 'tagged_item': 'Note: Remember the milk'
- }
+ {"tag": "reminder", "tagged_item": "Note: Remember the milk"},
]
assert serializer.data == expected
diff --git a/tests/importable/test_installed.py b/tests/importable/test_installed.py
index 072d3b2e4..8130a79ff 100644
--- a/tests/importable/test_installed.py
+++ b/tests/importable/test_installed.py
@@ -5,9 +5,9 @@ from tests import importable
def test_installed():
# ensure that apps can freely import rest_framework.compat
- assert 'tests.importable' in settings.INSTALLED_APPS
+ assert "tests.importable" in settings.INSTALLED_APPS
def test_imported():
# ensure that the __init__ hasn't been mucked with
- assert hasattr(importable, 'compat')
+ assert hasattr(importable, "compat")
diff --git a/tests/models.py b/tests/models.py
index 17bf23cda..0f6628a28 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -12,7 +12,7 @@ class RESTFrameworkModel(models.Model):
"""
class Meta:
- app_label = 'tests'
+ app_label = "tests"
abstract = True
@@ -20,7 +20,7 @@ class BasicModel(RESTFrameworkModel):
text = models.CharField(
max_length=100,
verbose_name=_("Text comes here"),
- help_text=_("Text description.")
+ help_text=_("Text description."),
)
@@ -32,7 +32,7 @@ class ManyToManyTarget(RESTFrameworkModel):
class ManyToManySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
- targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
+ targets = models.ManyToManyField(ManyToManyTarget, related_name="sources")
# ForeignKey
@@ -47,51 +47,74 @@ class UUIDForeignKeyTarget(RESTFrameworkModel):
class ForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, related_name='sources',
- help_text='Target', verbose_name='Target',
- on_delete=models.CASCADE)
+ target = models.ForeignKey(
+ ForeignKeyTarget,
+ related_name="sources",
+ help_text="Target",
+ verbose_name="Target",
+ on_delete=models.CASCADE,
+ )
class ForeignKeySourceWithLimitedChoices(RESTFrameworkModel):
- target = models.ForeignKey(ForeignKeyTarget, help_text='Target',
- verbose_name='Target',
- limit_choices_to={"name__startswith": "limited-"},
- on_delete=models.CASCADE)
+ target = models.ForeignKey(
+ ForeignKeyTarget,
+ help_text="Target",
+ verbose_name="Target",
+ limit_choices_to={"name__startswith": "limited-"},
+ on_delete=models.CASCADE,
+ )
class ForeignKeySourceWithQLimitedChoices(RESTFrameworkModel):
- target = models.ForeignKey(ForeignKeyTarget, help_text='Target',
- verbose_name='Target',
- limit_choices_to=models.Q(name__startswith="limited-"),
- on_delete=models.CASCADE)
+ target = models.ForeignKey(
+ ForeignKeyTarget,
+ help_text="Target",
+ verbose_name="Target",
+ limit_choices_to=models.Q(name__startswith="limited-"),
+ on_delete=models.CASCADE,
+ )
# Nullable ForeignKey
class NullableForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources',
- verbose_name='Optional target object',
- on_delete=models.CASCADE)
+ target = models.ForeignKey(
+ ForeignKeyTarget,
+ null=True,
+ blank=True,
+ related_name="nullable_sources",
+ verbose_name="Optional target object",
+ on_delete=models.CASCADE,
+ )
class NullableUUIDForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
- target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
- related_name='nullable_sources',
- verbose_name='Optional target object',
- on_delete=models.CASCADE)
+ target = models.ForeignKey(
+ ForeignKeyTarget,
+ null=True,
+ blank=True,
+ related_name="nullable_sources",
+ verbose_name="Optional target object",
+ on_delete=models.CASCADE,
+ )
class NestedForeignKeySource(RESTFrameworkModel):
"""
Used for testing FK chain. A -> B -> C.
"""
+
name = models.CharField(max_length=100)
- target = models.ForeignKey(NullableForeignKeySource, null=True, blank=True,
- related_name='nested_sources',
- verbose_name='Intermediate target object',
- on_delete=models.CASCADE)
+ target = models.ForeignKey(
+ NullableForeignKeySource,
+ null=True,
+ blank=True,
+ related_name="nested_sources",
+ verbose_name="Intermediate target object",
+ on_delete=models.CASCADE,
+ )
# OneToOne
@@ -102,13 +125,21 @@ class OneToOneTarget(RESTFrameworkModel):
class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.OneToOneField(
- OneToOneTarget, null=True, blank=True,
- related_name='nullable_source', on_delete=models.CASCADE)
+ OneToOneTarget,
+ null=True,
+ blank=True,
+ related_name="nullable_source",
+ on_delete=models.CASCADE,
+ )
class OneToOnePKSource(RESTFrameworkModel):
""" Test model where the primary key is a OneToOneField with another model. """
+
name = models.CharField(max_length=100)
target = models.OneToOneField(
- OneToOneTarget, primary_key=True,
- related_name='required_source', on_delete=models.CASCADE)
+ OneToOneTarget,
+ primary_key=True,
+ related_name="required_source",
+ on_delete=models.CASCADE,
+ )
diff --git a/tests/test_api_client.py b/tests/test_api_client.py
index e4354ec60..f2c14c0e1 100644
--- a/tests/test_api_client.py
+++ b/tests/test_api_client.py
@@ -18,52 +18,75 @@ from rest_framework.views import APIView
def get_schema():
return coreapi.Document(
- url='https://api.example.com/',
- title='Example API',
+ url="https://api.example.com/",
+ title="Example API",
content={
- 'simple_link': coreapi.Link('/example/', description='example link'),
- 'headers': coreapi.Link('/headers/'),
- 'location': {
- 'query': coreapi.Link('/example/', fields=[
- coreapi.Field(name='example', schema=coreschema.String(description='example field'))
- ]),
- 'form': coreapi.Link('/example/', action='post', fields=[
- coreapi.Field(name='example')
- ]),
- 'body': coreapi.Link('/example/', action='post', fields=[
- coreapi.Field(name='example', location='body')
- ]),
- 'path': coreapi.Link('/example/{id}', fields=[
- coreapi.Field(name='id', location='path')
- ])
+ "simple_link": coreapi.Link("/example/", description="example link"),
+ "headers": coreapi.Link("/headers/"),
+ "location": {
+ "query": coreapi.Link(
+ "/example/",
+ fields=[
+ coreapi.Field(
+ name="example",
+ schema=coreschema.String(description="example field"),
+ )
+ ],
+ ),
+ "form": coreapi.Link(
+ "/example/", action="post", fields=[coreapi.Field(name="example")]
+ ),
+ "body": coreapi.Link(
+ "/example/",
+ action="post",
+ fields=[coreapi.Field(name="example", location="body")],
+ ),
+ "path": coreapi.Link(
+ "/example/{id}", fields=[coreapi.Field(name="id", location="path")]
+ ),
},
- 'encoding': {
- 'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
- coreapi.Field(name='example')
- ]),
- 'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
- coreapi.Field(name='example', location='body')
- ]),
- 'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
- coreapi.Field(name='example')
- ]),
- 'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
- coreapi.Field(name='example', location='body')
- ]),
- 'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[
- coreapi.Field(name='example', location='body')
- ]),
+ "encoding": {
+ "multipart": coreapi.Link(
+ "/example/",
+ action="post",
+ encoding="multipart/form-data",
+ fields=[coreapi.Field(name="example")],
+ ),
+ "multipart-body": coreapi.Link(
+ "/example/",
+ action="post",
+ encoding="multipart/form-data",
+ fields=[coreapi.Field(name="example", location="body")],
+ ),
+ "urlencoded": coreapi.Link(
+ "/example/",
+ action="post",
+ encoding="application/x-www-form-urlencoded",
+ fields=[coreapi.Field(name="example")],
+ ),
+ "urlencoded-body": coreapi.Link(
+ "/example/",
+ action="post",
+ encoding="application/x-www-form-urlencoded",
+ fields=[coreapi.Field(name="example", location="body")],
+ ),
+ "raw_upload": coreapi.Link(
+ "/upload/",
+ action="post",
+ encoding="application/octet-stream",
+ fields=[coreapi.Field(name="example", location="body")],
+ ),
},
- 'response': {
- 'download': coreapi.Link('/download/'),
- 'text': coreapi.Link('/text/')
- }
- }
+ "response": {
+ "download": coreapi.Link("/download/"),
+ "text": coreapi.Link("/text/"),
+ },
+ },
)
def _iterlists(querydict):
- if hasattr(querydict, 'iterlists'):
+ if hasattr(querydict, "iterlists"):
return querydict.iterlists()
return querydict.lists()
@@ -73,8 +96,7 @@ def _get_query_params(request):
# than one item is present for a given key.
return {
key: (value[0] if len(value) == 1 else value)
- for key, value in
- _iterlists(request.query_params)
+ for key, value in _iterlists(request.query_params)
}
@@ -83,7 +105,7 @@ def _get_data(request):
return request.data
# Coerce multidict into regular dict, and remove files to
# make assertions simpler.
- if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'):
+ if hasattr(request.data, "iterlists") or hasattr(request.data, "lists"):
# Use a list value if a QueryDict contains multiple items for a key.
return {
key: value[0] if len(value) == 1 else value
@@ -91,9 +113,7 @@ def _get_data(request):
if key not in request.FILES
}
return {
- key: value
- for key, value in request.data.items()
- if key not in request.FILES
+ key: value for key, value in request.data.items() if key not in request.FILES
}
@@ -101,7 +121,7 @@ def _get_files(request):
if not request.FILES:
return {}
return {
- key: {'name': value.name, 'content': value.read()}
+ key: {"name": value.name, "content": value.read()}
for key, value in request.FILES.items()
}
@@ -116,210 +136,207 @@ class SchemaView(APIView):
class ListView(APIView):
def get(self, request):
- return Response({
- 'method': request.method,
- 'query_params': _get_query_params(request)
- })
+ return Response(
+ {"method": request.method, "query_params": _get_query_params(request)}
+ )
def post(self, request):
if request.content_type:
- content_type = request.content_type.split(';')[0]
+ content_type = request.content_type.split(";")[0]
else:
content_type = None
- return Response({
- 'method': request.method,
- 'query_params': _get_query_params(request),
- 'data': _get_data(request),
- 'files': _get_files(request),
- 'content_type': content_type
- })
+ return Response(
+ {
+ "method": request.method,
+ "query_params": _get_query_params(request),
+ "data": _get_data(request),
+ "files": _get_files(request),
+ "content_type": content_type,
+ }
+ )
class DetailView(APIView):
def get(self, request, id):
- return Response({
- 'id': id,
- 'method': request.method,
- 'query_params': _get_query_params(request)
- })
+ return Response(
+ {
+ "id": id,
+ "method": request.method,
+ "query_params": _get_query_params(request),
+ }
+ )
class UploadView(APIView):
parser_classes = [FileUploadParser]
def post(self, request):
- return Response({
- 'method': request.method,
- 'files': _get_files(request),
- 'content_type': request.content_type
- })
+ return Response(
+ {
+ "method": request.method,
+ "files": _get_files(request),
+ "content_type": request.content_type,
+ }
+ )
class DownloadView(APIView):
def get(self, request):
- return HttpResponse('some file content', content_type='image/png')
+ return HttpResponse("some file content", content_type="image/png")
class TextView(APIView):
def get(self, request):
- return HttpResponse('123', content_type='text/plain')
+ return HttpResponse("123", content_type="text/plain")
class HeadersView(APIView):
def get(self, request):
headers = {
- key[5:].replace('_', '-'): value
+ key[5:].replace("_", "-"): value
for key, value in request.META.items()
- if key.startswith('HTTP_')
+ if key.startswith("HTTP_")
}
- return Response({
- 'method': request.method,
- 'headers': headers
- })
+ return Response({"method": request.method, "headers": headers})
urlpatterns = [
- url(r'^$', SchemaView.as_view()),
- url(r'^example/$', ListView.as_view()),
- url(r'^example/(?P[0-9]+)/$', DetailView.as_view()),
- url(r'^upload/$', UploadView.as_view()),
- url(r'^download/$', DownloadView.as_view()),
- url(r'^text/$', TextView.as_view()),
- url(r'^headers/$', HeadersView.as_view()),
+ url(r"^$", SchemaView.as_view()),
+ url(r"^example/$", ListView.as_view()),
+ url(r"^example/(?P[0-9]+)/$", DetailView.as_view()),
+ url(r"^upload/$", UploadView.as_view()),
+ url(r"^download/$", DownloadView.as_view()),
+ url(r"^text/$", TextView.as_view()),
+ url(r"^headers/$", HeadersView.as_view()),
]
-@unittest.skipUnless(coreapi, 'coreapi not installed')
-@override_settings(ROOT_URLCONF='tests.test_api_client')
+@unittest.skipUnless(coreapi, "coreapi not installed")
+@override_settings(ROOT_URLCONF="tests.test_api_client")
class APIClientTests(APITestCase):
def test_api_client(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- assert schema.title == 'Example API'
- assert schema.url == 'https://api.example.com/'
- assert schema['simple_link'].description == 'example link'
- assert schema['location']['query'].fields[0].schema.description == 'example field'
- data = client.action(schema, ['simple_link'])
- expected = {
- 'method': 'GET',
- 'query_params': {}
- }
+ schema = client.get("http://api.example.com/")
+ assert schema.title == "Example API"
+ assert schema.url == "https://api.example.com/"
+ assert schema["simple_link"].description == "example link"
+ assert (
+ schema["location"]["query"].fields[0].schema.description == "example field"
+ )
+ data = client.action(schema, ["simple_link"])
+ expected = {"method": "GET", "query_params": {}}
assert data == expected
def test_query_params(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['location', 'query'], params={'example': 123})
- expected = {
- 'method': 'GET',
- 'query_params': {'example': '123'}
- }
+ schema = client.get("http://api.example.com/")
+ data = client.action(schema, ["location", "query"], params={"example": 123})
+ expected = {"method": "GET", "query_params": {"example": "123"}}
assert data == expected
def test_session_headers(self):
client = CoreAPIClient()
- client.session.headers.update({'X-Custom-Header': 'foo'})
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['headers'])
- assert data['headers']['X-CUSTOM-HEADER'] == 'foo'
+ client.session.headers.update({"X-Custom-Header": "foo"})
+ schema = client.get("http://api.example.com/")
+ data = client.action(schema, ["headers"])
+ assert data["headers"]["X-CUSTOM-HEADER"] == "foo"
def test_query_params_with_multiple_values(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]})
- expected = {
- 'method': 'GET',
- 'query_params': {'example': ['1', '2', '3']}
- }
+ schema = client.get("http://api.example.com/")
+ data = client.action(
+ schema, ["location", "query"], params={"example": [1, 2, 3]}
+ )
+ expected = {"method": "GET", "query_params": {"example": ["1", "2", "3"]}}
assert data == expected
def test_form_params(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['location', 'form'], params={'example': 123})
+ schema = client.get("http://api.example.com/")
+ data = client.action(schema, ["location", "form"], params={"example": 123})
expected = {
- 'method': 'POST',
- 'content_type': 'application/json',
- 'query_params': {},
- 'data': {'example': 123},
- 'files': {}
+ "method": "POST",
+ "content_type": "application/json",
+ "query_params": {},
+ "data": {"example": 123},
+ "files": {},
}
assert data == expected
def test_body_params(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['location', 'body'], params={'example': 123})
+ schema = client.get("http://api.example.com/")
+ data = client.action(schema, ["location", "body"], params={"example": 123})
expected = {
- 'method': 'POST',
- 'content_type': 'application/json',
- 'query_params': {},
- 'data': 123,
- 'files': {}
+ "method": "POST",
+ "content_type": "application/json",
+ "query_params": {},
+ "data": 123,
+ "files": {},
}
assert data == expected
def test_path_params(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['location', 'path'], params={'id': 123})
- expected = {
- 'method': 'GET',
- 'query_params': {},
- 'id': '123'
- }
+ schema = client.get("http://api.example.com/")
+ data = client.action(schema, ["location", "path"], params={"id": 123})
+ expected = {"method": "GET", "query_params": {}, "id": "123"}
assert data == expected
def test_multipart_encoding(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
with tempfile.NamedTemporaryFile() as temp:
- temp.write(b'example file content')
+ temp.write(b"example file content")
temp.flush()
temp.seek(0)
name = os.path.basename(temp.name)
- data = client.action(schema, ['encoding', 'multipart'], params={'example': temp})
+ data = client.action(
+ schema, ["encoding", "multipart"], params={"example": temp}
+ )
expected = {
- 'method': 'POST',
- 'content_type': 'multipart/form-data',
- 'query_params': {},
- 'data': {},
- 'files': {'example': {'name': name, 'content': 'example file content'}}
+ "method": "POST",
+ "content_type": "multipart/form-data",
+ "query_params": {},
+ "data": {},
+ "files": {"example": {"name": name, "content": "example file content"}},
}
assert data == expected
def test_multipart_encoding_no_file(self):
# When no file is included, multipart encoding should still be used.
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- data = client.action(schema, ['encoding', 'multipart'], params={'example': 123})
+ data = client.action(schema, ["encoding", "multipart"], params={"example": 123})
expected = {
- 'method': 'POST',
- 'content_type': 'multipart/form-data',
- 'query_params': {},
- 'data': {'example': '123'},
- 'files': {}
+ "method": "POST",
+ "content_type": "multipart/form-data",
+ "query_params": {},
+ "data": {"example": "123"},
+ "files": {},
}
assert data == expected
def test_multipart_encoding_multiple_values(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]})
+ data = client.action(
+ schema, ["encoding", "multipart"], params={"example": [1, 2, 3]}
+ )
expected = {
- 'method': 'POST',
- 'content_type': 'multipart/form-data',
- 'query_params': {},
- 'data': {'example': ['1', '2', '3']},
- 'files': {}
+ "method": "POST",
+ "content_type": "multipart/form-data",
+ "query_params": {},
+ "data": {"example": ["1", "2", "3"]},
+ "files": {},
}
assert data == expected
@@ -328,17 +345,19 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- example = File(name='example.txt', content='123')
- data = client.action(schema, ['encoding', 'multipart'], params={'example': example})
+ example = File(name="example.txt", content="123")
+ data = client.action(
+ schema, ["encoding", "multipart"], params={"example": example}
+ )
expected = {
- 'method': 'POST',
- 'content_type': 'multipart/form-data',
- 'query_params': {},
- 'data': {},
- 'files': {'example': {'name': 'example.txt', 'content': '123'}}
+ "method": "POST",
+ "content_type": "multipart/form-data",
+ "query_params": {},
+ "data": {},
+ "files": {"example": {"name": "example.txt", "content": "123"}},
}
assert data == expected
@@ -346,17 +365,19 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'}
- data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example})
+ example = {"foo": File(name="example.txt", content="123"), "bar": "abc"}
+ data = client.action(
+ schema, ["encoding", "multipart-body"], params={"example": example}
+ )
expected = {
- 'method': 'POST',
- 'content_type': 'multipart/form-data',
- 'query_params': {},
- 'data': {'bar': 'abc'},
- 'files': {'foo': {'name': 'example.txt', 'content': '123'}}
+ "method": "POST",
+ "content_type": "multipart/form-data",
+ "query_params": {},
+ "data": {"bar": "abc"},
+ "files": {"foo": {"name": "example.txt", "content": "123"}},
}
assert data == expected
@@ -364,40 +385,48 @@ class APIClientTests(APITestCase):
def test_urlencoded_encoding(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123})
+ schema = client.get("http://api.example.com/")
+ data = client.action(
+ schema, ["encoding", "urlencoded"], params={"example": 123}
+ )
expected = {
- 'method': 'POST',
- 'content_type': 'application/x-www-form-urlencoded',
- 'query_params': {},
- 'data': {'example': '123'},
- 'files': {}
+ "method": "POST",
+ "content_type": "application/x-www-form-urlencoded",
+ "query_params": {},
+ "data": {"example": "123"},
+ "files": {},
}
assert data == expected
def test_urlencoded_encoding_multiple_values(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]})
+ schema = client.get("http://api.example.com/")
+ data = client.action(
+ schema, ["encoding", "urlencoded"], params={"example": [1, 2, 3]}
+ )
expected = {
- 'method': 'POST',
- 'content_type': 'application/x-www-form-urlencoded',
- 'query_params': {},
- 'data': {'example': ['1', '2', '3']},
- 'files': {}
+ "method": "POST",
+ "content_type": "application/x-www-form-urlencoded",
+ "query_params": {},
+ "data": {"example": ["1", "2", "3"]},
+ "files": {},
}
assert data == expected
def test_urlencoded_encoding_in_body(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
- data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}})
+ schema = client.get("http://api.example.com/")
+ data = client.action(
+ schema,
+ ["encoding", "urlencoded-body"],
+ params={"example": {"foo": 123, "bar": True}},
+ )
expected = {
- 'method': 'POST',
- 'content_type': 'application/x-www-form-urlencoded',
- 'query_params': {},
- 'data': {'foo': '123', 'bar': 'true'},
- 'files': {}
+ "method": "POST",
+ "content_type": "application/x-www-form-urlencoded",
+ "query_params": {},
+ "data": {"foo": "123", "bar": "true"},
+ "files": {},
}
assert data == expected
@@ -405,20 +434,22 @@ class APIClientTests(APITestCase):
def test_raw_upload(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
with tempfile.NamedTemporaryFile(delete=False) as temp:
- temp.write(b'example file content')
+ temp.write(b"example file content")
temp.flush()
temp.seek(0)
name = os.path.basename(temp.name)
- data = client.action(schema, ['encoding', 'raw_upload'], params={'example': temp})
+ data = client.action(
+ schema, ["encoding", "raw_upload"], params={"example": temp}
+ )
expected = {
- 'method': 'POST',
- 'files': {'file': {'name': name, 'content': 'example file content'}},
- 'content_type': 'application/octet-stream'
+ "method": "POST",
+ "files": {"file": {"name": name, "content": "example file content"}},
+ "content_type": "application/octet-stream",
}
assert data == expected
@@ -426,15 +457,17 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- example = File('example.txt', '123')
- data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
+ example = File("example.txt", "123")
+ data = client.action(
+ schema, ["encoding", "raw_upload"], params={"example": example}
+ )
expected = {
- 'method': 'POST',
- 'files': {'file': {'name': 'example.txt', 'content': '123'}},
- 'content_type': 'text/plain'
+ "method": "POST",
+ "files": {"file": {"name": "example.txt", "content": "123"}},
+ "content_type": "text/plain",
}
assert data == expected
@@ -442,15 +475,17 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- example = File('example.txt', '123', 'text/html')
- data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
+ example = File("example.txt", "123", "text/html")
+ data = client.action(
+ schema, ["encoding", "raw_upload"], params={"example": example}
+ )
expected = {
- 'method': 'POST',
- 'files': {'file': {'name': 'example.txt', 'content': '123'}},
- 'content_type': 'text/html'
+ "method": "POST",
+ "files": {"file": {"name": "example.txt", "content": "123"}},
+ "content_type": "text/html",
}
assert data == expected
@@ -458,17 +493,17 @@ class APIClientTests(APITestCase):
def test_text_response(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- data = client.action(schema, ['response', 'text'])
+ data = client.action(schema, ["response", "text"])
- expected = '123'
+ expected = "123"
assert data == expected
def test_download_response(self):
client = CoreAPIClient()
- schema = client.get('http://api.example.com/')
+ schema = client.get("http://api.example.com/")
- data = client.action(schema, ['response', 'download'])
- assert data.basename == 'download.png'
- assert data.read() == b'some file content'
+ data = client.action(schema, ["response", "download"])
+ assert data.basename == "download.png"
+ assert data.read() == b"some file content"
diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py
index bddd480a5..9dab585d7 100644
--- a/tests/test_atomic_requests.py
+++ b/tests/test_atomic_requests.py
@@ -14,13 +14,14 @@ from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from tests.models import BasicModel
+
factory = APIRequestFactory()
class BasicView(APIView):
def post(self, request, *args, **kwargs):
BasicModel.objects.create()
- return Response({'method': 'GET'})
+ return Response({"method": "GET"})
class ErrorView(APIView):
@@ -45,25 +46,23 @@ class NonAtomicAPIExceptionView(APIView):
raise Http404
-urlpatterns = (
- url(r'^$', NonAtomicAPIExceptionView.as_view()),
-)
+urlpatterns = (url(r"^$", NonAtomicAPIExceptionView.as_view()),)
@unittest.skipUnless(
connection.features.uses_savepoints,
- "'atomic' requires transactions and savepoints."
+ "'atomic' requires transactions and savepoints.",
)
class DBTransactionTests(TestCase):
def setUp(self):
self.view = BasicView.as_view()
- connections.databases['default']['ATOMIC_REQUESTS'] = True
+ connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = False
+ connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_no_exception_commit_transaction(self):
- request = factory.post('/')
+ request = factory.post("/")
with self.assertNumQueries(1):
response = self.view(request)
@@ -74,15 +73,15 @@ class DBTransactionTests(TestCase):
@unittest.skipUnless(
connection.features.uses_savepoints,
- "'atomic' requires transactions and savepoints."
+ "'atomic' requires transactions and savepoints.",
)
class DBTransactionErrorTests(TestCase):
def setUp(self):
self.view = ErrorView.as_view()
- connections.databases['default']['ATOMIC_REQUESTS'] = True
+ connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = False
+ connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_generic_exception_delegate_transaction_management(self):
"""
@@ -91,7 +90,7 @@ class DBTransactionErrorTests(TestCase):
We let django deal with the transaction when it will catch the Exception.
"""
- request = factory.post('/')
+ request = factory.post("/")
with self.assertNumQueries(3):
# 1 - begin savepoint
# 2 - insert
@@ -104,21 +103,21 @@ class DBTransactionErrorTests(TestCase):
@unittest.skipUnless(
connection.features.uses_savepoints,
- "'atomic' requires transactions and savepoints."
+ "'atomic' requires transactions and savepoints.",
)
class DBTransactionAPIExceptionTests(TestCase):
def setUp(self):
self.view = APIExceptionView.as_view()
- connections.databases['default']['ATOMIC_REQUESTS'] = True
+ connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = False
+ connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_api_exception_rollback_transaction(self):
"""
Transaction is rollbacked by our transaction atomic block.
"""
- request = factory.post('/')
+ request = factory.post("/")
num_queries = 4 if connection.features.can_release_savepoints else 3
with self.assertNumQueries(num_queries):
# 1 - begin savepoint
@@ -134,18 +133,18 @@ class DBTransactionAPIExceptionTests(TestCase):
@unittest.skipUnless(
connection.features.uses_savepoints,
- "'atomic' requires transactions and savepoints."
+ "'atomic' requires transactions and savepoints.",
)
-@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
+@override_settings(ROOT_URLCONF="tests.test_atomic_requests")
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
def setUp(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = True
+ connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
- connections.databases['default']['ATOMIC_REQUESTS'] = False
+ connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_api_exception_rollback_transaction_non_atomic_view(self):
- response = self.client.get('/')
+ response = self.client.get("/")
# without checking connection.in_atomic_block view raises 500
# due attempt to rollback without transaction
diff --git a/tests/test_authtoken.py b/tests/test_authtoken.py
index c8957f978..aaa8ef29c 100644
--- a/tests/test_authtoken.py
+++ b/tests/test_authtoken.py
@@ -6,44 +6,43 @@ from django.test import TestCase
from django.utils.six import StringIO
from rest_framework.authtoken.admin import TokenAdmin
-from rest_framework.authtoken.management.commands.drf_create_token import \
- Command as AuthTokenCommand
+from rest_framework.authtoken.management.commands.drf_create_token import (
+ Command as AuthTokenCommand,
+)
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError
class AuthTokenTests(TestCase):
-
def setUp(self):
self.site = site
- self.user = User.objects.create_user(username='test_user')
- self.token = Token.objects.create(key='test token', user=self.user)
+ self.user = User.objects.create_user(username="test_user")
+ self.token = Token.objects.create(key="test token", user=self.user)
def test_model_admin_displayed_fields(self):
mock_request = object()
token_admin = TokenAdmin(self.token, self.site)
- assert token_admin.get_fields(mock_request) == ('user',)
+ assert token_admin.get_fields(mock_request) == ("user",)
def test_token_string_representation(self):
- assert str(self.token) == 'test token'
+ assert str(self.token) == "test token"
def test_validate_raise_error_if_no_credentials_provided(self):
with pytest.raises(ValidationError):
AuthTokenSerializer().validate({})
def test_whitespace_in_password(self):
- data = {'username': self.user.username, 'password': 'test pass '}
- self.user.set_password(data['password'])
+ data = {"username": self.user.username, "password": "test pass "}
+ self.user.set_password(data["password"])
self.user.save()
assert AuthTokenSerializer(data=data).is_valid()
class AuthTokenCommandTests(TestCase):
-
def setUp(self):
self.site = site
- self.user = User.objects.create_user(username='test_user')
+ self.user = User.objects.create_user(username="test_user")
def test_command_create_user_token(self):
token = AuthTokenCommand().create_user_token(self.user.username, False)
@@ -53,7 +52,7 @@ class AuthTokenCommandTests(TestCase):
def test_command_create_user_token_invalid_user(self):
with pytest.raises(User.DoesNotExist):
- AuthTokenCommand().create_user_token('not_existing_user', False)
+ AuthTokenCommand().create_user_token("not_existing_user", False)
def test_command_reset_user_token(self):
AuthTokenCommand().create_user_token(self.user.username, False)
@@ -74,12 +73,12 @@ class AuthTokenCommandTests(TestCase):
def test_command_raising_error_for_invalid_user(self):
out = StringIO()
with pytest.raises(CommandError):
- call_command('drf_create_token', 'not_existing_user', stdout=out)
+ call_command("drf_create_token", "not_existing_user", stdout=out)
def test_command_output(self):
out = StringIO()
- call_command('drf_create_token', self.user.username, stdout=out)
+ call_command("drf_create_token", self.user.username, stdout=out)
token_saved = Token.objects.first()
- self.assertIn('Generated token', out.getvalue())
+ self.assertIn("Generated token", out.getvalue())
self.assertIn(self.user.username, out.getvalue())
self.assertIn(token_saved.key, out.getvalue())
diff --git a/tests/test_bound_fields.py b/tests/test_bound_fields.py
index e588ae623..e369dc759 100644
--- a/tests/test_bound_fields.py
+++ b/tests/test_bound_fields.py
@@ -11,41 +11,43 @@ class TestSimpleBoundField:
serializer = ExampleSerializer()
- assert serializer['text'].value == ''
- assert serializer['text'].errors is None
- assert serializer['text'].name == 'text'
- assert serializer['amount'].value is None
- assert serializer['amount'].errors is None
- assert serializer['amount'].name == 'amount'
+ assert serializer["text"].value == ""
+ assert serializer["text"].errors is None
+ assert serializer["text"].name == "text"
+ assert serializer["amount"].value is None
+ assert serializer["amount"].errors is None
+ assert serializer["amount"].name == "amount"
def test_populated_bound_field(self):
class ExampleSerializer(serializers.Serializer):
text = serializers.CharField(max_length=100)
amount = serializers.IntegerField()
- serializer = ExampleSerializer(data={'text': 'abc', 'amount': 123})
+ serializer = ExampleSerializer(data={"text": "abc", "amount": 123})
assert serializer.is_valid()
- assert serializer['text'].value == 'abc'
- assert serializer['text'].errors is None
- assert serializer['text'].name == 'text'
- assert serializer['amount'].value is 123
- assert serializer['amount'].errors is None
- assert serializer['amount'].name == 'amount'
+ assert serializer["text"].value == "abc"
+ assert serializer["text"].errors is None
+ assert serializer["text"].name == "text"
+ assert serializer["amount"].value is 123
+ assert serializer["amount"].errors is None
+ assert serializer["amount"].name == "amount"
def test_error_bound_field(self):
class ExampleSerializer(serializers.Serializer):
text = serializers.CharField(max_length=100)
amount = serializers.IntegerField()
- serializer = ExampleSerializer(data={'text': 'x' * 1000, 'amount': 123})
+ serializer = ExampleSerializer(data={"text": "x" * 1000, "amount": 123})
serializer.is_valid()
- assert serializer['text'].value == 'x' * 1000
- assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.']
- assert serializer['text'].name == 'text'
- assert serializer['amount'].value is 123
- assert serializer['amount'].errors is None
- assert serializer['amount'].name == 'amount'
+ assert serializer["text"].value == "x" * 1000
+ assert serializer["text"].errors == [
+ "Ensure this field has no more than 100 characters."
+ ]
+ assert serializer["text"].name == "text"
+ assert serializer["amount"].value is 123
+ assert serializer["amount"].errors is None
+ assert serializer["amount"].name == "amount"
def test_delete_field(self):
class ExampleSerializer(serializers.Serializer):
@@ -53,41 +55,45 @@ class TestSimpleBoundField:
amount = serializers.IntegerField()
serializer = ExampleSerializer()
- del serializer.fields['text']
- assert 'text' not in serializer.fields
+ del serializer.fields["text"]
+ assert "text" not in serializer.fields
def test_as_form_fields(self):
class ExampleSerializer(serializers.Serializer):
bool_field = serializers.BooleanField()
null_field = serializers.IntegerField(allow_null=True)
- serializer = ExampleSerializer(data={'bool_field': False, 'null_field': None})
+ serializer = ExampleSerializer(data={"bool_field": False, "null_field": None})
assert serializer.is_valid()
- assert serializer['bool_field'].as_form_field().value == ''
- assert serializer['null_field'].as_form_field().value == ''
+ assert serializer["bool_field"].as_form_field().value == ""
+ assert serializer["null_field"].as_form_field().value == ""
def test_rendering_boolean_field(self):
from rest_framework.renderers import HTMLFormRenderer
class ExampleSerializer(serializers.Serializer):
bool_field = serializers.BooleanField(
- style={'base_template': 'checkbox.html', 'template_pack': 'rest_framework/vertical'})
+ style={
+ "base_template": "checkbox.html",
+ "template_pack": "rest_framework/vertical",
+ }
+ )
- serializer = ExampleSerializer(data={'bool_field': True})
+ serializer = ExampleSerializer(data={"bool_field": True})
assert serializer.is_valid()
renderer = HTMLFormRenderer()
- rendered = renderer.render_field(serializer['bool_field'], {})
+ rendered = renderer.render_field(serializer["bool_field"], {})
expected_packed = (
''
''
- '"
+ ""
+ ""
)
- rendered_packed = ''.join(rendered.split())
+ rendered_packed = "".join(rendered.split())
assert rendered_packed == expected_packed
@@ -103,15 +109,15 @@ class TestNestedBoundField:
serializer = ExampleSerializer()
- assert serializer['text'].value == ''
- assert serializer['text'].errors is None
- assert serializer['text'].name == 'text'
- assert serializer['nested']['more_text'].value == ''
- assert serializer['nested']['more_text'].errors is None
- assert serializer['nested']['more_text'].name == 'nested.more_text'
- assert serializer['nested']['amount'].value is None
- assert serializer['nested']['amount'].errors is None
- assert serializer['nested']['amount'].name == 'nested.amount'
+ assert serializer["text"].value == ""
+ assert serializer["text"].errors is None
+ assert serializer["text"].name == "text"
+ assert serializer["nested"]["more_text"].value == ""
+ assert serializer["nested"]["more_text"].errors is None
+ assert serializer["nested"]["more_text"].name == "nested.more_text"
+ assert serializer["nested"]["amount"].value is None
+ assert serializer["nested"]["amount"].errors is None
+ assert serializer["nested"]["amount"].name == "nested.amount"
def test_as_form_fields(self):
class Nested(serializers.Serializer):
@@ -121,10 +127,12 @@ class TestNestedBoundField:
class ExampleSerializer(serializers.Serializer):
nested = Nested()
- serializer = ExampleSerializer(data={'nested': {'bool_field': False, 'null_field': None}})
+ serializer = ExampleSerializer(
+ data={"nested": {"bool_field": False, "null_field": None}}
+ )
assert serializer.is_valid()
- assert serializer['nested']['bool_field'].as_form_field().value == ''
- assert serializer['nested']['null_field'].as_form_field().value == ''
+ assert serializer["nested"]["bool_field"].as_form_field().value == ""
+ assert serializer["nested"]["null_field"].as_form_field().value == ""
def test_rendering_nested_fields_with_none_value(self):
from rest_framework.renderers import HTMLFormRenderer
@@ -139,28 +147,30 @@ class TestNestedBoundField:
class ExampleSerializer(serializers.Serializer):
nested2 = Nested2()
- serializer = ExampleSerializer(data={'nested2': {'nested1': None, 'text_field': 'test'}})
+ serializer = ExampleSerializer(
+ data={"nested2": {"nested1": None, "text_field": "test"}}
+ )
assert serializer.is_valid()
renderer = HTMLFormRenderer()
for field in serializer:
rendered = renderer.render_field(field, {})
expected_packed = (
- '"
)
- rendered_packed = ''.join(rendered.split())
+ rendered_packed = "".join(rendered.split())
assert rendered_packed == expected_packed
@@ -170,7 +180,7 @@ class TestJSONBoundField:
json_field = serializers.JSONField()
data = QueryDict(mutable=True)
- data.update({'json_field': '{"some": ["json"}'})
+ data.update({"json_field": '{"some": ["json"}'})
serializer = TestSerializer(data=data)
assert serializer.is_valid() is False
- assert serializer['json_field'].as_form_field().value == '{"some": ["json"}'
+ assert serializer["json_field"].as_form_field().value == '{"some": ["json"}'
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
index 13dd41ff3..a8804cfd1 100644
--- a/tests/test_decorators.py
+++ b/tests/test_decorators.py
@@ -6,9 +6,16 @@ from django.test import TestCase
from rest_framework import RemovedInDRF310Warning, status
from rest_framework.authentication import BasicAuthentication
from rest_framework.decorators import (
- action, api_view, authentication_classes, detail_route, list_route,
- parser_classes, permission_classes, renderer_classes, schema,
- throttle_classes
+ action,
+ api_view,
+ authentication_classes,
+ detail_route,
+ list_route,
+ parser_classes,
+ permission_classes,
+ renderer_classes,
+ schema,
+ throttle_classes,
)
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
@@ -21,7 +28,6 @@ from rest_framework.views import APIView
class DecoratorTestCase(TestCase):
-
def setUp(self):
self.factory = APIRequestFactory()
@@ -38,7 +44,7 @@ class DecoratorTestCase(TestCase):
def view(request):
return Response()
- request = self.factory.get('/')
+ request = self.factory.get("/")
self.assertRaises(AssertionError, view, request)
def test_api_view_incorrect_arguments(self):
@@ -47,108 +53,102 @@ class DecoratorTestCase(TestCase):
"""
with self.assertRaises(AssertionError):
- @api_view('GET')
+
+ @api_view("GET")
def view(request):
return Response()
def test_calling_method(self):
-
- @api_view(['GET'])
+ @api_view(["GET"])
def view(request):
return Response({})
- request = self.factory.get('/')
+ request = self.factory.get("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
- request = self.factory.post('/')
+ request = self.factory.post("/")
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_put_method(self):
-
- @api_view(['GET', 'PUT'])
+ @api_view(["GET", "PUT"])
def view(request):
return Response({})
- request = self.factory.put('/')
+ request = self.factory.put("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
- request = self.factory.post('/')
+ request = self.factory.post("/")
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_patch_method(self):
-
- @api_view(['GET', 'PATCH'])
+ @api_view(["GET", "PATCH"])
def view(request):
return Response({})
- request = self.factory.patch('/')
+ request = self.factory.patch("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
- request = self.factory.post('/')
+ request = self.factory.post("/")
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_renderer_classes(self):
-
- @api_view(['GET'])
+ @api_view(["GET"])
@renderer_classes([JSONRenderer])
def view(request):
return Response({})
- request = self.factory.get('/')
+ request = self.factory.get("/")
response = view(request)
assert isinstance(response.accepted_renderer, JSONRenderer)
def test_parser_classes(self):
-
- @api_view(['GET'])
+ @api_view(["GET"])
@parser_classes([JSONParser])
def view(request):
assert len(request.parsers) == 1
assert isinstance(request.parsers[0], JSONParser)
return Response({})
- request = self.factory.get('/')
+ request = self.factory.get("/")
view(request)
def test_authentication_classes(self):
-
- @api_view(['GET'])
+ @api_view(["GET"])
@authentication_classes([BasicAuthentication])
def view(request):
assert len(request.authenticators) == 1
assert isinstance(request.authenticators[0], BasicAuthentication)
return Response({})
- request = self.factory.get('/')
+ request = self.factory.get("/")
view(request)
def test_permission_classes(self):
-
- @api_view(['GET'])
+ @api_view(["GET"])
@permission_classes([IsAuthenticated])
def view(request):
return Response({})
- request = self.factory.get('/')
+ request = self.factory.get("/")
response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle):
- rate = '1/day'
+ rate = "1/day"
- @api_view(['GET'])
+ @api_view(["GET"])
@throttle_classes([OncePerDayUserThrottle])
def view(request):
return Response({})
- request = self.factory.get('/')
+ request = self.factory.get("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
@@ -159,10 +159,11 @@ class DecoratorTestCase(TestCase):
"""
Checks CustomSchema class is set on view
"""
+
class CustomSchema(AutoSchema):
pass
- @api_view(['GET'])
+ @api_view(["GET"])
@schema(CustomSchema())
def view(request):
return Response({})
@@ -171,23 +172,23 @@ class DecoratorTestCase(TestCase):
class ActionDecoratorTestCase(TestCase):
-
def test_defaults(self):
@action(detail=True)
def test_action(request):
"""Description"""
- assert test_action.mapping == {'get': 'test_action'}
+ assert test_action.mapping == {"get": "test_action"}
assert test_action.detail is True
- assert test_action.url_path == 'test_action'
- assert test_action.url_name == 'test-action'
+ assert test_action.url_path == "test_action"
+ assert test_action.url_name == "test-action"
assert test_action.kwargs == {
- 'name': 'Test action',
- 'description': 'Description',
+ "name": "Test action",
+ "description": "Description",
}
def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo:
+
@action()
def test_action(request):
raise NotImplementedError
@@ -201,6 +202,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError
for name in APIView.http_method_names:
+
def method():
raise NotImplementedError
@@ -222,36 +224,30 @@ class ActionDecoratorTestCase(TestCase):
def test_action(request):
raise NotImplementedError
- assert test_action.kwargs == {
- 'description': None,
- 'name': 'Test action',
- }
+ assert test_action.kwargs == {"description": None, "name": "Test action"}
# name kwarg supersedes name generation
- @action(detail=True, name='test name')
+ @action(detail=True, name="test name")
def test_action(request):
raise NotImplementedError
- assert test_action.kwargs == {
- 'description': None,
- 'name': 'test name',
- }
+ assert test_action.kwargs == {"description": None, "name": "test name"}
# suffix kwarg supersedes name generation
- @action(detail=True, suffix='Suffix')
+ @action(detail=True, suffix="Suffix")
def test_action(request):
raise NotImplementedError
- assert test_action.kwargs == {
- 'description': None,
- 'suffix': 'Suffix',
- }
+ assert test_action.kwargs == {"description": None, "suffix": "Suffix"}
# name + suffix is a conflict.
with pytest.raises(TypeError) as excinfo:
- action(detail=True, name='test name', suffix='Suffix')
+ action(detail=True, name="test name", suffix="Suffix")
- assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments."
+ assert (
+ str(excinfo.value)
+ == "`name` and `suffix` are mutually exclusive arguments."
+ )
def test_method_mapping(self):
@action(detail=False)
@@ -263,7 +259,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError
# The secondary handler methods should not have the action attributes
- for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']:
+ for name in ["mapping", "detail", "url_path", "url_name", "kwargs"]:
assert hasattr(test_action, name) and not hasattr(test_action_post, name)
def test_method_mapping_already_mapped(self):
@@ -273,6 +269,7 @@ class ActionDecoratorTestCase(TestCase):
msg = "Method 'get' has already been mapped to '.test_action'."
with self.assertRaisesMessage(AssertionError, msg):
+
@test_action.mapping.get
def test_action_get(request):
raise NotImplementedError
@@ -282,15 +279,19 @@ class ActionDecoratorTestCase(TestCase):
def test_action():
raise NotImplementedError
- msg = ("Method mapping does not behave like the property decorator. You "
- "cannot use the same method name for each mapping declaration.")
+ msg = (
+ "Method mapping does not behave like the property decorator. You "
+ "cannot use the same method name for each mapping declaration."
+ )
with self.assertRaisesMessage(AssertionError, msg):
+
@test_action.mapping.post
def test_action():
raise NotImplementedError
def test_detail_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record:
+
@detail_route()
def view(request):
raise NotImplementedError
@@ -304,6 +305,7 @@ class ActionDecoratorTestCase(TestCase):
def test_list_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record:
+
@list_route()
def view(request):
raise NotImplementedError
@@ -318,9 +320,10 @@ class ActionDecoratorTestCase(TestCase):
def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path`
with pytest.warns(RemovedInDRF310Warning):
- @list_route(url_path='foo_bar')
+
+ @list_route(url_path="foo_bar")
def view(request):
raise NotImplementedError
- assert view.url_path == 'foo_bar'
- assert view.url_name == 'foo-bar'
+ assert view.url_path == "foo_bar"
+ assert view.url_name == "foo-bar"
diff --git a/tests/test_description.py b/tests/test_description.py
index 702e56332..9aac9555f 100644
--- a/tests/test_description.py
+++ b/tests/test_description.py
@@ -9,6 +9,7 @@ from rest_framework.compat import apply_markdown
from rest_framework.utils.formatting import dedent
from rest_framework.views import APIView
+
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
====================
@@ -81,28 +82,34 @@ class TestViewNamesAndDescriptions(TestCase):
"""
Ensure view names are based on the class name.
"""
+
class MockView(APIView):
pass
- assert MockView().get_view_name() == 'Mock'
+
+ assert MockView().get_view_name() == "Mock"
def test_view_name_uses_name_attribute(self):
class MockView(APIView):
- name = 'Foo'
- assert MockView().get_view_name() == 'Foo'
+ name = "Foo"
+
+ assert MockView().get_view_name() == "Foo"
def test_view_name_uses_suffix_attribute(self):
class MockView(APIView):
- suffix = 'List'
- assert MockView().get_view_name() == 'Mock List'
+ suffix = "List"
+
+ assert MockView().get_view_name() == "Mock List"
def test_view_name_preferences_name_over_suffix(self):
class MockView(APIView):
- name = 'Foo'
- suffix = 'List'
- assert MockView().get_view_name() == 'Foo'
+ name = "Foo"
+ suffix = "List"
+
+ assert MockView().get_view_name() == "Foo"
def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring."""
+
class MockView(APIView):
"""an example docstring
====================
@@ -130,23 +137,28 @@ class TestViewNamesAndDescriptions(TestCase):
def test_view_description_uses_description_attribute(self):
class MockView(APIView):
- description = 'Foo'
- assert MockView().get_view_description() == 'Foo'
+ description = "Foo"
+
+ assert MockView().get_view_description() == "Foo"
def test_view_description_allows_empty_description(self):
class MockView(APIView):
"""Description."""
- description = ''
- assert MockView().get_view_description() == ''
+
+ description = ""
+
+ assert MockView().get_view_description() == ""
def test_view_description_can_be_empty(self):
"""
Ensure that if a view has no docstring,
then it's description is the empty string.
"""
+
class MockView(APIView):
pass
- assert MockView().get_view_description() == ''
+
+ assert MockView().get_view_description() == ""
def test_view_description_can_be_promise(self):
"""
@@ -168,7 +180,7 @@ class TestViewNamesAndDescriptions(TestCase):
class MockView(APIView):
__doc__ = MockLazyStr("a gettext string")
- assert MockView().get_view_description() == 'a gettext string'
+ assert MockView().get_view_description() == "a gettext string"
def test_markdown(self):
"""
@@ -176,21 +188,17 @@ class TestViewNamesAndDescriptions(TestCase):
"""
if apply_markdown:
md_applied = apply_markdown(DESCRIPTION)
- gte_21_match = (
- md_applied == (
- MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or
- md_applied == (
- MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
- lt_21_match = (
- md_applied == (
- MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
- md_applied == (
- MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
+ gte_21_match = md_applied == (
+ MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE
+ ) or md_applied == (MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE)
+ lt_21_match = md_applied == (
+ MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE
+ ) or md_applied == (MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE)
assert gte_21_match or lt_21_match
def test_dedent_tabs():
- result = 'first string\n\nsecond string'
+ result = "first string\n\nsecond string"
assert dedent(" first string\n\n second string") == result
assert dedent("first string\n\n second string") == result
assert dedent("\tfirst string\n\n\tsecond string") == result
diff --git a/tests/test_encoders.py b/tests/test_encoders.py
index 12eca8105..0fad7eb91 100644
--- a/tests/test_encoders.py
+++ b/tests/test_encoders.py
@@ -37,7 +37,7 @@ class JSONEncoderTests(TestCase):
current_time = datetime.now()
assert self.encoder.default(current_time) == current_time.isoformat()
current_time_utc = current_time.replace(tzinfo=utc)
- assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z'
+ assert self.encoder.default(current_time_utc) == current_time.isoformat() + "Z"
def test_encode_time(self):
"""
@@ -76,7 +76,7 @@ class JSONEncoderTests(TestCase):
unique_id = uuid4()
assert self.encoder.default(unique_id) == str(unique_id)
- @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+ @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_encode_coreapi_raises_error(self):
"""
Tests encoding a coreapi objects raises proper error
diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py
index ce0ed8514..98478f908 100644
--- a/tests/test_exceptions.py
+++ b/tests/test_exceptions.py
@@ -6,13 +6,16 @@ from django.utils import six, translation
from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import (
- APIException, ErrorDetail, Throttled, _get_error_details, bad_request,
- server_error
+ APIException,
+ ErrorDetail,
+ Throttled,
+ _get_error_details,
+ bad_request,
+ server_error,
)
class ExceptionTestCase(TestCase):
-
def test_get_error_details(self):
example = "string"
@@ -20,91 +23,96 @@ class ExceptionTestCase(TestCase):
assert _get_error_details(lazy_example) == example
- assert isinstance(
- _get_error_details(lazy_example),
- ErrorDetail
- )
+ assert isinstance(_get_error_details(lazy_example), ErrorDetail)
- assert _get_error_details({'nested': lazy_example})['nested'] == example
+ assert _get_error_details({"nested": lazy_example})["nested"] == example
assert isinstance(
- _get_error_details({'nested': lazy_example})['nested'],
- ErrorDetail
+ _get_error_details({"nested": lazy_example})["nested"], ErrorDetail
)
assert _get_error_details([[lazy_example]])[0][0] == example
- assert isinstance(
- _get_error_details([[lazy_example]])[0][0],
- ErrorDetail
- )
+ assert isinstance(_get_error_details([[lazy_example]])[0][0], ErrorDetail)
def test_get_full_details_with_throttling(self):
exception = Throttled()
assert exception.get_full_details() == {
- 'message': 'Request was throttled.', 'code': 'throttled'}
+ "message": "Request was throttled.",
+ "code": "throttled",
+ }
exception = Throttled(wait=2)
assert exception.get_full_details() == {
- 'message': 'Request was throttled. Expected available in {} seconds.'.format(2 if six.PY3 else 2.),
- 'code': 'throttled'}
+ "message": "Request was throttled. Expected available in {} seconds.".format(
+ 2 if six.PY3 else 2.0
+ ),
+ "code": "throttled",
+ }
- exception = Throttled(wait=2, detail='Slow down!')
+ exception = Throttled(wait=2, detail="Slow down!")
assert exception.get_full_details() == {
- 'message': 'Slow down! Expected available in {} seconds.'.format(2 if six.PY3 else 2.),
- 'code': 'throttled'}
+ "message": "Slow down! Expected available in {} seconds.".format(
+ 2 if six.PY3 else 2.0
+ ),
+ "code": "throttled",
+ }
class ErrorDetailTests(TestCase):
-
def test_eq(self):
- assert ErrorDetail('msg') == ErrorDetail('msg')
- assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
+ assert ErrorDetail("msg") == ErrorDetail("msg")
+ assert ErrorDetail("msg", "code") == ErrorDetail("msg", code="code")
- assert ErrorDetail('msg') == 'msg'
- assert ErrorDetail('msg', 'code') == 'msg'
+ assert ErrorDetail("msg") == "msg"
+ assert ErrorDetail("msg", "code") == "msg"
def test_ne(self):
- assert ErrorDetail('msg1') != ErrorDetail('msg2')
- assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
+ assert ErrorDetail("msg1") != ErrorDetail("msg2")
+ assert ErrorDetail("msg") != ErrorDetail("msg", code="invalid")
- assert ErrorDetail('msg1') != 'msg2'
- assert ErrorDetail('msg1', 'code') != 'msg2'
+ assert ErrorDetail("msg1") != "msg2"
+ assert ErrorDetail("msg1", "code") != "msg2"
def test_repr(self):
- assert repr(ErrorDetail('msg1')) == \
- 'ErrorDetail(string={!r}, code=None)'.format('msg1')
- assert repr(ErrorDetail('msg1', 'code')) == \
- 'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
+ assert repr(
+ ErrorDetail("msg1")
+ ) == "ErrorDetail(string={!r}, code=None)".format("msg1")
+ assert repr(
+ ErrorDetail("msg1", "code")
+ ) == "ErrorDetail(string={!r}, code={!r})".format("msg1", "code")
def test_str(self):
- assert str(ErrorDetail('msg1')) == 'msg1'
- assert str(ErrorDetail('msg1', 'code')) == 'msg1'
+ assert str(ErrorDetail("msg1")) == "msg1"
+ assert str(ErrorDetail("msg1", "code")) == "msg1"
def test_hash(self):
- assert hash(ErrorDetail('msg')) == hash('msg')
- assert hash(ErrorDetail('msg', 'code')) == hash('msg')
+ assert hash(ErrorDetail("msg")) == hash("msg")
+ assert hash(ErrorDetail("msg", "code")) == hash("msg")
class TranslationTests(TestCase):
-
- @translation.override('fr')
+ @translation.override("fr")
def test_message(self):
# this test largely acts as a sanity test to ensure the translation files are present.
- self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.')
- self.assertEqual(six.text_type(APIException()), 'Une erreur du serveur est survenue.')
+ self.assertEqual(
+ _("A server error occurred."), "Une erreur du serveur est survenue."
+ )
+ self.assertEqual(
+ six.text_type(APIException()), "Une erreur du serveur est survenue."
+ )
def test_server_error():
- request = RequestFactory().get('/')
+ request = RequestFactory().get("/")
response = server_error(request)
assert response.status_code == 500
- assert response["content-type"] == 'application/json'
+ assert response["content-type"] == "application/json"
def test_bad_request():
- request = RequestFactory().get('/')
- exception = Exception('Something went wrong — Not used')
+ request = RequestFactory().get("/")
+ exception = Exception("Something went wrong — Not used")
response = bad_request(request, exception)
assert response.status_code == 400
- assert response["content-type"] == 'application/json'
+ assert response["content-type"] == "application/json"
diff --git a/tests/test_fields.py b/tests/test_fields.py
index 12c936b22..a538bba4f 100644
--- a/tests/test_fields.py
+++ b/tests/test_fields.py
@@ -18,6 +18,7 @@ from rest_framework import exceptions, serializers
from rest_framework.compat import ProhibitNullCharactersValidator
from rest_framework.fields import DjangoImageField, is_simple_callable
+
try:
import typings
except ImportError:
@@ -27,8 +28,8 @@ except ImportError:
# Tests for helper functions.
# ---------------------------
-class TestIsSimpleCallable:
+class TestIsSimpleCallable:
def test_method(self):
class Foo:
@classmethod
@@ -38,7 +39,7 @@ class TestIsSimpleCallable:
def valid(self):
pass
- def valid_kwargs(self, param='value'):
+ def valid_kwargs(self, param="value"):
pass
def valid_vargs_kwargs(self, *args, **kwargs):
@@ -65,13 +66,13 @@ class TestIsSimpleCallable:
def simple():
pass
- def valid(param='value', param2='value'):
+ def valid(param="value", param2="value"):
pass
def valid_vargs_kwargs(*args, **kwargs):
pass
- def invalid(param, param2='value'):
+ def invalid(param, param2="value"):
pass
assert is_simple_callable(simple)
@@ -84,20 +85,19 @@ class TestIsSimpleCallable:
class ChoiceModel(models.Model):
choice_field = models.CharField(
- max_length=1, default='a',
- choices=(('a', 'A'), ('b', 'B')),
+ max_length=1, default="a", choices=(("a", "A"), ("b", "B"))
)
class Meta:
- app_label = 'tests'
+ app_label = "tests"
assert is_simple_callable(ChoiceModel().get_choice_field_display)
- @unittest.skipUnless(typings, 'requires python 3.5')
+ @unittest.skipUnless(typings, "requires python 3.5")
def test_type_annotation(self):
# The annotation will otherwise raise a syntax error in python < 3.5
exec("def valid(param: str='value'): pass", locals())
- valid = locals()['valid']
+ valid = locals()["valid"]
assert is_simple_callable(valid)
@@ -105,10 +105,12 @@ class TestIsSimpleCallable:
# Tests for field keyword arguments and core functionality.
# ---------------------------------------------------------
+
class TestEmpty:
"""
Tests for `required`, `allow_null`, `allow_blank`, `default`.
"""
+
def test_required(self):
"""
By default a field must be included in the input.
@@ -116,7 +118,7 @@ class TestEmpty:
field = serializers.IntegerField()
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation()
- assert exc_info.value.detail == ['This field is required.']
+ assert exc_info.value.detail == ["This field is required."]
def test_not_required(self):
"""
@@ -133,7 +135,7 @@ class TestEmpty:
field = serializers.IntegerField()
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation(None)
- assert exc_info.value.detail == ['This field may not be null.']
+ assert exc_info.value.detail == ["This field may not be null."]
def test_allow_null(self):
"""
@@ -149,16 +151,16 @@ class TestEmpty:
"""
field = serializers.CharField()
with pytest.raises(serializers.ValidationError) as exc_info:
- field.run_validation('')
- assert exc_info.value.detail == ['This field may not be blank.']
+ field.run_validation("")
+ assert exc_info.value.detail == ["This field may not be blank."]
def test_allow_blank(self):
"""
If `allow_blank=True` then '' is a valid input.
"""
field = serializers.CharField(allow_blank=True)
- output = field.run_validation('')
- assert output == ''
+ output = field.run_validation("")
+ assert output == ""
def test_default(self):
"""
@@ -172,14 +174,16 @@ class TestEmpty:
class TestSource:
def test_source(self):
class ExampleSerializer(serializers.Serializer):
- example_field = serializers.CharField(source='other')
- serializer = ExampleSerializer(data={'example_field': 'abc'})
+ example_field = serializers.CharField(source="other")
+
+ serializer = ExampleSerializer(data={"example_field": "abc"})
assert serializer.is_valid()
- assert serializer.validated_data == {'other': 'abc'}
+ assert serializer.validated_data == {"other": "abc"}
def test_redundant_source(self):
class ExampleSerializer(serializers.Serializer):
- example_field = serializers.CharField(source='example_field')
+ example_field = serializers.CharField(source="example_field")
+
with pytest.raises(AssertionError) as exc_info:
ExampleSerializer().fields
assert str(exc_info.value) == (
@@ -190,28 +194,30 @@ class TestSource:
def test_callable_source(self):
class ExampleSerializer(serializers.Serializer):
- example_field = serializers.CharField(source='example_callable')
+ example_field = serializers.CharField(source="example_callable")
class ExampleInstance(object):
def example_callable(self):
- return 'example callable value'
+ return "example callable value"
serializer = ExampleSerializer(ExampleInstance())
- assert serializer.data['example_field'] == 'example callable value'
+ assert serializer.data["example_field"] == "example callable value"
def test_callable_source_raises(self):
class ExampleSerializer(serializers.Serializer):
- example_field = serializers.CharField(source='example_callable', read_only=True)
+ example_field = serializers.CharField(
+ source="example_callable", read_only=True
+ )
class ExampleInstance(object):
def example_callable(self):
- raise AttributeError('method call failed')
+ raise AttributeError("method call failed")
with pytest.raises(ValueError) as exc_info:
serializer = ExampleSerializer(ExampleInstance())
serializer.data.items()
- assert 'method call failed' in str(exc_info.value)
+ assert "method call failed" in str(exc_info.value)
class TestReadOnly:
@@ -219,6 +225,7 @@ class TestReadOnly:
class TestSerializer(serializers.Serializer):
read_only = serializers.ReadOnlyField(default="789")
writable = serializers.IntegerField()
+
self.Serializer = TestSerializer
def test_writable_fields(self):
@@ -232,18 +239,18 @@ class TestReadOnly:
"""
Read-only serializers.should not be included in validation.
"""
- data = {'read_only': 123, 'writable': 456}
+ data = {"read_only": 123, "writable": 456}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {'writable': 456}
+ assert serializer.validated_data == {"writable": 456}
def test_serialize_read_only(self):
"""
Read-only serializers.should be serialized.
"""
- instance = {'read_only': 123, 'writable': 456}
+ instance = {"read_only": 123, "writable": 456}
serializer = self.Serializer(instance)
- assert serializer.data == {'read_only': 123, 'writable': 456}
+ assert serializer.data == {"read_only": 123, "writable": 456}
class TestWriteOnly:
@@ -251,24 +258,25 @@ class TestWriteOnly:
class TestSerializer(serializers.Serializer):
write_only = serializers.IntegerField(write_only=True)
readable = serializers.IntegerField()
+
self.Serializer = TestSerializer
def test_validate_write_only(self):
"""
Write-only serializers.should be included in validation.
"""
- data = {'write_only': 123, 'readable': 456}
+ data = {"write_only": 123, "readable": 456}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {'write_only': 123, 'readable': 456}
+ assert serializer.validated_data == {"write_only": 123, "readable": 456}
def test_serialize_write_only(self):
"""
Write-only serializers.should not be serialized.
"""
- instance = {'write_only': 123, 'readable': 456}
+ instance = {"write_only": 123, "readable": 456}
serializer = self.Serializer(instance)
- assert serializer.data == {'readable': 456}
+ assert serializer.data == {"readable": 456}
class TestInitial:
@@ -276,16 +284,14 @@ class TestInitial:
class TestSerializer(serializers.Serializer):
initial_field = serializers.IntegerField(initial=123)
blank_field = serializers.IntegerField()
+
self.serializer = TestSerializer()
def test_initial(self):
"""
Initial values should be included when serializing a new representation.
"""
- assert self.serializer.data == {
- 'initial_field': 123,
- 'blank_field': None
- }
+ assert self.serializer.data == {"initial_field": 123, "blank_field": None}
class TestInitialWithCallable:
@@ -295,21 +301,21 @@ class TestInitialWithCallable:
class TestSerializer(serializers.Serializer):
initial_field = serializers.IntegerField(initial=initial_value)
+
self.serializer = TestSerializer()
def test_initial_should_accept_callable(self):
"""
Follows the default ``Field.initial`` behaviour where they accept a
callable to produce the initial value"""
- assert self.serializer.data == {
- 'initial_field': 123,
- }
+ assert self.serializer.data == {"initial_field": 123}
class TestLabel:
def setup(self):
class TestSerializer(serializers.Serializer):
- labeled = serializers.IntegerField(label='My label')
+ labeled = serializers.IntegerField(label="My label")
+
self.serializer = TestSerializer()
def test_label(self):
@@ -317,14 +323,15 @@ class TestLabel:
A field's label may be set with the `label` argument.
"""
fields = self.serializer.fields
- assert fields['labeled'].label == 'My label'
+ assert fields["labeled"].label == "My label"
class TestInvalidErrorKey:
def setup(self):
class ExampleField(serializers.Field):
def to_native(self, data):
- self.fail('incorrect')
+ self.fail("incorrect")
+
self.field = ExampleField()
def test_invalid_error_key(self):
@@ -335,8 +342,8 @@ class TestInvalidErrorKey:
with pytest.raises(AssertionError) as exc_info:
self.field.to_native(123)
expected = (
- 'ValidationError raised by `ExampleField`, but error key '
- '`incorrect` does not exist in the `error_messages` dictionary.'
+ "ValidationError raised by `ExampleField`, but error key "
+ "`incorrect` does not exist in the `error_messages` dictionary."
)
assert str(exc_info.value) == expected
@@ -347,72 +354,74 @@ class TestBooleanHTMLInput:
HTML checkboxes do not send any value, but should be treated
as `False` by BooleanField.
"""
+
class TestSerializer(serializers.Serializer):
archived = serializers.BooleanField()
- serializer = TestSerializer(data=QueryDict(''))
+ serializer = TestSerializer(data=QueryDict(""))
assert serializer.is_valid()
- assert serializer.validated_data == {'archived': False}
+ assert serializer.validated_data == {"archived": False}
def test_empty_html_checkbox_not_required(self):
"""
HTML checkboxes do not send any value, but should be treated
as `False` by BooleanField, even if the field is required=False.
"""
+
class TestSerializer(serializers.Serializer):
archived = serializers.BooleanField(required=False)
- serializer = TestSerializer(data=QueryDict(''))
+ serializer = TestSerializer(data=QueryDict(""))
assert serializer.is_valid()
- assert serializer.validated_data == {'archived': False}
+ assert serializer.validated_data == {"archived": False}
class TestHTMLInput:
def test_empty_html_charfield_with_default(self):
class TestSerializer(serializers.Serializer):
- message = serializers.CharField(default='happy')
+ message = serializers.CharField(default="happy")
- serializer = TestSerializer(data=QueryDict(''))
+ serializer = TestSerializer(data=QueryDict(""))
assert serializer.is_valid()
- assert serializer.validated_data == {'message': 'happy'}
+ assert serializer.validated_data == {"message": "happy"}
def test_empty_html_charfield_without_default(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(allow_blank=True)
- serializer = TestSerializer(data=QueryDict('message='))
+ serializer = TestSerializer(data=QueryDict("message="))
assert serializer.is_valid()
- assert serializer.validated_data == {'message': ''}
+ assert serializer.validated_data == {"message": ""}
def test_empty_html_charfield_without_default_not_required(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(allow_blank=True, required=False)
- serializer = TestSerializer(data=QueryDict('message='))
+ serializer = TestSerializer(data=QueryDict("message="))
assert serializer.is_valid()
- assert serializer.validated_data == {'message': ''}
+ assert serializer.validated_data == {"message": ""}
def test_empty_html_integerfield(self):
class TestSerializer(serializers.Serializer):
message = serializers.IntegerField(default=123)
- serializer = TestSerializer(data=QueryDict('message='))
+ serializer = TestSerializer(data=QueryDict("message="))
assert serializer.is_valid()
- assert serializer.validated_data == {'message': 123}
+ assert serializer.validated_data == {"message": 123}
def test_empty_html_uuidfield_with_default(self):
class TestSerializer(serializers.Serializer):
message = serializers.UUIDField(default=uuid.uuid4)
- serializer = TestSerializer(data=QueryDict('message='))
+ serializer = TestSerializer(data=QueryDict("message="))
assert serializer.is_valid()
- assert list(serializer.validated_data) == ['message']
+ assert list(serializer.validated_data) == ["message"]
def test_empty_html_uuidfield_with_optional(self):
class TestSerializer(serializers.Serializer):
message = serializers.UUIDField(required=False)
- serializer = TestSerializer(data=QueryDict('message='))
+ serializer = TestSerializer(data=QueryDict("message="))
assert serializer.is_valid()
assert list(serializer.validated_data) == []
@@ -420,31 +429,31 @@ class TestHTMLInput:
class TestSerializer(serializers.Serializer):
message = serializers.CharField(allow_null=True)
- serializer = TestSerializer(data=QueryDict('message='))
+ serializer = TestSerializer(data=QueryDict("message="))
assert serializer.is_valid()
- assert serializer.validated_data == {'message': None}
+ assert serializer.validated_data == {"message": None}
def test_empty_html_datefield_allow_null(self):
class TestSerializer(serializers.Serializer):
expiry = serializers.DateField(allow_null=True)
- serializer = TestSerializer(data=QueryDict('expiry='))
+ serializer = TestSerializer(data=QueryDict("expiry="))
assert serializer.is_valid()
- assert serializer.validated_data == {'expiry': None}
+ assert serializer.validated_data == {"expiry": None}
def test_empty_html_charfield_allow_null_allow_blank(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(allow_null=True, allow_blank=True)
- serializer = TestSerializer(data=QueryDict('message='))
+ serializer = TestSerializer(data=QueryDict("message="))
assert serializer.is_valid()
- assert serializer.validated_data == {'message': ''}
+ assert serializer.validated_data == {"message": ""}
def test_empty_html_charfield_required_false(self):
class TestSerializer(serializers.Serializer):
message = serializers.CharField(required=False)
- serializer = TestSerializer(data=QueryDict(''))
+ serializer = TestSerializer(data=QueryDict(""))
assert serializer.is_valid()
assert serializer.validated_data == {}
@@ -452,52 +461,55 @@ class TestHTMLInput:
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(child=serializers.IntegerField())
- serializer = TestSerializer(data=QueryDict('scores=1&scores=3'))
+ serializer = TestSerializer(data=QueryDict("scores=1&scores=3"))
assert serializer.is_valid()
- assert serializer.validated_data == {'scores': [1, 3]}
+ assert serializer.validated_data == {"scores": [1, 3]}
def test_querydict_list_input_only_one_input(self):
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(child=serializers.IntegerField())
- serializer = TestSerializer(data=QueryDict('scores=1&'))
+ serializer = TestSerializer(data=QueryDict("scores=1&"))
assert serializer.is_valid()
- assert serializer.validated_data == {'scores': [1]}
+ assert serializer.validated_data == {"scores": [1]}
def test_querydict_list_input_no_values_uses_default(self):
"""
When there are no values passed in, and default is set
The field should return the default value
"""
+
class TestSerializer(serializers.Serializer):
a = serializers.IntegerField(required=True)
scores = serializers.ListField(default=lambda: [1, 3])
- serializer = TestSerializer(data=QueryDict('a=1&'))
+ serializer = TestSerializer(data=QueryDict("a=1&"))
assert serializer.is_valid()
- assert serializer.validated_data == {'a': 1, 'scores': [1, 3]}
+ assert serializer.validated_data == {"a": 1, "scores": [1, 3]}
def test_querydict_list_input_supports_indexed_keys(self):
"""
When data is passed in the format `scores[0]=1&scores[1]=3`
The field should return the correct list, ignoring the default
"""
+
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(default=lambda: [1, 3])
serializer = TestSerializer(data=QueryDict("scores[0]=5&scores[1]=6"))
assert serializer.is_valid()
- assert serializer.validated_data == {'scores': ['5', '6']}
+ assert serializer.validated_data == {"scores": ["5", "6"]}
def test_querydict_list_input_no_values_no_default_and_not_required(self):
"""
When there are no keys passed, there is no default, and required=False
The field should be skipped
"""
+
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(required=False)
- serializer = TestSerializer(data=QueryDict(''))
+ serializer = TestSerializer(data=QueryDict(""))
assert serializer.is_valid()
assert serializer.validated_data == {}
@@ -506,58 +518,60 @@ class TestHTMLInput:
When there are no keys passed, there is no default, and required=False
The field should return an array of 1 item, blank
"""
+
class TestSerializer(serializers.Serializer):
scores = serializers.ListField(required=False)
- serializer = TestSerializer(data=QueryDict('scores=&'))
+ serializer = TestSerializer(data=QueryDict("scores=&"))
assert serializer.is_valid()
- assert serializer.validated_data == {'scores': ['']}
+ assert serializer.validated_data == {"scores": [""]}
class TestCreateOnlyDefault:
def setup(self):
- default = serializers.CreateOnlyDefault('2001-01-01')
+ default = serializers.CreateOnlyDefault("2001-01-01")
class TestSerializer(serializers.Serializer):
published = serializers.HiddenField(default=default)
text = serializers.CharField()
+
self.Serializer = TestSerializer
def test_create_only_default_is_provided(self):
- serializer = self.Serializer(data={'text': 'example'})
+ serializer = self.Serializer(data={"text": "example"})
assert serializer.is_valid()
assert serializer.validated_data == {
- 'text': 'example', 'published': '2001-01-01'
+ "text": "example",
+ "published": "2001-01-01",
}
def test_create_only_default_is_not_provided_on_update(self):
- instance = {
- 'text': 'example', 'published': '2001-01-01'
- }
- serializer = self.Serializer(instance, data={'text': 'example'})
+ instance = {"text": "example", "published": "2001-01-01"}
+ serializer = self.Serializer(instance, data={"text": "example"})
assert serializer.is_valid()
- assert serializer.validated_data == {
- 'text': 'example',
- }
+ assert serializer.validated_data == {"text": "example"}
def test_create_only_default_callable_sets_context(self):
"""
CreateOnlyDefault instances with a callable default should set_context
on the callable if possible
"""
+
class TestCallableDefault:
def set_context(self, serializer_field):
self.field = serializer_field
def __call__(self):
- return "success" if hasattr(self, 'field') else "failure"
+ return "success" if hasattr(self, "field") else "failure"
class TestSerializer(serializers.Serializer):
- context_set = serializers.CharField(default=serializers.CreateOnlyDefault(TestCallableDefault()))
+ context_set = serializers.CharField(
+ default=serializers.CreateOnlyDefault(TestCallableDefault())
+ )
serializer = TestSerializer(data={})
assert serializer.is_valid()
- assert serializer.validated_data['context_set'] == 'success'
+ assert serializer.validated_data["context_set"] == "success"
class Test5087Regression:
@@ -566,13 +580,14 @@ class Test5087Regression:
field = serializers.CharField()
assert field.root is field
- field.bind('name', parent)
+ field.bind("name", parent)
assert field.root is parent
# Tests for field input and output values.
# ----------------------------------------
+
def get_items(mapping_or_list_of_two_tuples):
# Tests accept either lists of two tuples, or dictionaries.
if isinstance(mapping_or_list_of_two_tuples, dict):
@@ -586,13 +601,15 @@ class FieldValues:
"""
Base class for testing valid and invalid input values.
"""
+
def test_valid_inputs(self):
"""
Ensure that valid values return the expected validated data.
"""
for input_value, expected_output in get_items(self.valid_inputs):
- assert self.field.run_validation(input_value) == expected_output, \
- 'input value: {}'.format(repr(input_value))
+ assert (
+ self.field.run_validation(input_value) == expected_output
+ ), "input value: {}".format(repr(input_value))
def test_invalid_inputs(self):
"""
@@ -601,58 +618,59 @@ class FieldValues:
for input_value, expected_failure in get_items(self.invalid_inputs):
with pytest.raises(serializers.ValidationError) as exc_info:
self.field.run_validation(input_value)
- assert exc_info.value.detail == expected_failure, \
- 'input value: {}'.format(repr(input_value))
+ assert exc_info.value.detail == expected_failure, "input value: {}".format(
+ repr(input_value)
+ )
def test_outputs(self):
for output_value, expected_output in get_items(self.outputs):
- assert self.field.to_representation(output_value) == expected_output, \
- 'output value: {}'.format(repr(output_value))
+ assert (
+ self.field.to_representation(output_value) == expected_output
+ ), "output value: {}".format(repr(output_value))
# Boolean types...
+
class TestBooleanField(FieldValues):
"""
Valid and invalid values for `BooleanField`.
"""
+
valid_inputs = {
- 'true': True,
- 'false': False,
- '1': True,
- '0': False,
+ "true": True,
+ "false": False,
+ "1": True,
+ "0": False,
1: True,
0: False,
True: True,
False: False,
}
invalid_inputs = {
- 'foo': ['Must be a valid boolean.'],
- None: ['This field may not be null.']
+ "foo": ["Must be a valid boolean."],
+ None: ["This field may not be null."],
}
outputs = {
- 'true': True,
- 'false': False,
- '1': True,
- '0': False,
+ "true": True,
+ "false": False,
+ "1": True,
+ "0": False,
1: True,
0: False,
True: True,
False: False,
- 'other': True
+ "other": True,
}
field = serializers.BooleanField()
def test_disallow_unhashable_collection_types(self):
- inputs = (
- [],
- {},
- )
+ inputs = ([], {})
field = self.field
for input_value in inputs:
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation(input_value)
- expected = ['Must be a valid boolean.'.format(input_value)]
+ expected = ["Must be a valid boolean.".format(input_value)]
assert exc_info.value.detail == expected
@@ -660,25 +678,24 @@ class TestNullBooleanField(TestBooleanField):
"""
Valid and invalid values for `NullBooleanField`.
"""
+
valid_inputs = {
- 'true': True,
- 'false': False,
- 'null': None,
- True: True,
- False: False,
- None: None
- }
- invalid_inputs = {
- 'foo': ['Must be a valid boolean.'],
- }
- outputs = {
- 'true': True,
- 'false': False,
- 'null': None,
+ "true": True,
+ "false": False,
+ "null": None,
True: True,
False: False,
None: None,
- 'other': True
+ }
+ invalid_inputs = {"foo": ["Must be a valid boolean."]}
+ outputs = {
+ "true": True,
+ "false": False,
+ "null": None,
+ True: True,
+ False: False,
+ None: None,
+ "other": True,
}
field = serializers.NullBooleanField()
@@ -695,82 +712,81 @@ class TestNullableBooleanField(TestNullBooleanField):
# String types...
+
class TestCharField(FieldValues):
"""
Valid and invalid values for `CharField`.
"""
- valid_inputs = {
- 1: '1',
- 'abc': 'abc'
- }
+
+ valid_inputs = {1: "1", "abc": "abc"}
invalid_inputs = {
- (): ['Not a valid string.'],
- True: ['Not a valid string.'],
- '': ['This field may not be blank.']
- }
- outputs = {
- 1: '1',
- 'abc': 'abc'
+ (): ["Not a valid string."],
+ True: ["Not a valid string."],
+ "": ["This field may not be blank."],
}
+ outputs = {1: "1", "abc": "abc"}
field = serializers.CharField()
def test_trim_whitespace_default(self):
field = serializers.CharField()
- assert field.to_internal_value(' abc ') == 'abc'
+ assert field.to_internal_value(" abc ") == "abc"
def test_trim_whitespace_disabled(self):
field = serializers.CharField(trim_whitespace=False)
- assert field.to_internal_value(' abc ') == ' abc '
+ assert field.to_internal_value(" abc ") == " abc "
def test_disallow_blank_with_trim_whitespace(self):
field = serializers.CharField(allow_blank=False, trim_whitespace=True)
with pytest.raises(serializers.ValidationError) as exc_info:
- field.run_validation(' ')
- assert exc_info.value.detail == ['This field may not be blank.']
+ field.run_validation(" ")
+ assert exc_info.value.detail == ["This field may not be blank."]
- @pytest.mark.skipif(ProhibitNullCharactersValidator is None, reason="Skipped on Django < 2.0")
+ @pytest.mark.skipif(
+ ProhibitNullCharactersValidator is None, reason="Skipped on Django < 2.0"
+ )
def test_null_bytes(self):
field = serializers.CharField()
- for value in ('\0', 'foo\0', '\0foo', 'foo\0foo'):
+ for value in ("\0", "foo\0", "\0foo", "foo\0foo"):
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation(value)
- assert exc_info.value.detail == [
- 'Null characters are not allowed.'
- ]
+ assert exc_info.value.detail == ["Null characters are not allowed."]
def test_iterable_validators(self):
"""
Ensure `validators` parameter is compatible with reasonable iterables.
"""
- value = 'example'
+ value = "example"
for validators in ([], (), set()):
field = serializers.CharField(validators=validators)
field.run_validation(value)
def raise_exception(value):
- raise exceptions.ValidationError('Raised error')
+ raise exceptions.ValidationError("Raised error")
- for validators in ([raise_exception], (raise_exception,), set([raise_exception])):
+ for validators in (
+ [raise_exception],
+ (raise_exception,),
+ set([raise_exception]),
+ ):
field = serializers.CharField(validators=validators)
with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation(value)
- assert exc_info.value.detail == ['Raised error']
+ assert exc_info.value.detail == ["Raised error"]
class TestEmailField(FieldValues):
"""
Valid and invalid values for `EmailField`.
"""
+
valid_inputs = {
- 'example@example.com': 'example@example.com',
- ' example@example.com ': 'example@example.com',
- }
- invalid_inputs = {
- 'examplecom': ['Enter a valid email address.']
+ "example@example.com": "example@example.com",
+ " example@example.com ": "example@example.com",
}
+ invalid_inputs = {"examplecom": ["Enter a valid email address."]}
outputs = {}
field = serializers.EmailField()
@@ -779,39 +795,34 @@ class TestRegexField(FieldValues):
"""
Valid and invalid values for `RegexField`.
"""
- valid_inputs = {
- 'a9': 'a9',
- }
- invalid_inputs = {
- 'A9': ["This value does not match the required pattern."]
- }
+
+ valid_inputs = {"a9": "a9"}
+ invalid_inputs = {"A9": ["This value does not match the required pattern."]}
outputs = {}
- field = serializers.RegexField(regex='[a-z][0-9]')
+ field = serializers.RegexField(regex="[a-z][0-9]")
class TestiCompiledRegexField(FieldValues):
"""
Valid and invalid values for `RegexField`.
"""
- valid_inputs = {
- 'a9': 'a9',
- }
- invalid_inputs = {
- 'A9': ["This value does not match the required pattern."]
- }
+
+ valid_inputs = {"a9": "a9"}
+ invalid_inputs = {"A9": ["This value does not match the required pattern."]}
outputs = {}
- field = serializers.RegexField(regex=re.compile('[a-z][0-9]'))
+ field = serializers.RegexField(regex=re.compile("[a-z][0-9]"))
class TestSlugField(FieldValues):
"""
Valid and invalid values for `SlugField`.
"""
- valid_inputs = {
- 'slug-99': 'slug-99',
- }
+
+ valid_inputs = {"slug-99": "slug-99"}
invalid_inputs = {
- 'slug 99': ['Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.']
+ "slug 99": [
+ 'Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'
+ ]
}
outputs = {}
field = serializers.SlugField()
@@ -821,7 +832,7 @@ class TestSlugField(FieldValues):
validation_error = False
try:
- field.run_validation(u'slug-99-\u0420')
+ field.run_validation(u"slug-99-\u0420")
except serializers.ValidationError:
validation_error = True
@@ -832,12 +843,9 @@ class TestURLField(FieldValues):
"""
Valid and invalid values for `URLField`.
"""
- valid_inputs = {
- 'http://example.com': 'http://example.com',
- }
- invalid_inputs = {
- 'example.com': ['Enter a valid URL.']
- }
+
+ valid_inputs = {"http://example.com": "http://example.com"}
+ invalid_inputs = {"example.com": ["Enter a valid URL."]}
outputs = {}
field = serializers.URLField()
@@ -846,18 +854,29 @@ class TestUUIDField(FieldValues):
"""
Valid and invalid values for `UUIDField`.
"""
+
valid_inputs = {
- '825d7aeb-05a9-45b5-a5b7-05df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'),
- '825d7aeb05a945b5a5b705df87923cda': uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'),
- 'urn:uuid:213b7d9b-244f-410d-828c-dabce7a2615d': uuid.UUID('213b7d9b-244f-410d-828c-dabce7a2615d'),
- 284758210125106368185219588917561929842: uuid.UUID('d63a6fb6-88d5-40c7-a91c-9edf73283072')
+ "825d7aeb-05a9-45b5-a5b7-05df87923cda": uuid.UUID(
+ "825d7aeb-05a9-45b5-a5b7-05df87923cda"
+ ),
+ "825d7aeb05a945b5a5b705df87923cda": uuid.UUID(
+ "825d7aeb-05a9-45b5-a5b7-05df87923cda"
+ ),
+ "urn:uuid:213b7d9b-244f-410d-828c-dabce7a2615d": uuid.UUID(
+ "213b7d9b-244f-410d-828c-dabce7a2615d"
+ ),
+ 284758210125106368185219588917561929842: uuid.UUID(
+ "d63a6fb6-88d5-40c7-a91c-9edf73283072"
+ ),
}
invalid_inputs = {
- '825d7aeb-05a9-45b5-a5b7': ['Must be a valid UUID.'],
- (1, 2, 3): ['Must be a valid UUID.']
+ "825d7aeb-05a9-45b5-a5b7": ["Must be a valid UUID."],
+ (1, 2, 3): ["Must be a valid UUID."],
}
outputs = {
- uuid.UUID('825d7aeb-05a9-45b5-a5b7-05df87923cda'): '825d7aeb-05a9-45b5-a5b7-05df87923cda'
+ uuid.UUID(
+ "825d7aeb-05a9-45b5-a5b7-05df87923cda"
+ ): "825d7aeb-05a9-45b5-a5b7-05df87923cda"
}
field = serializers.UUIDField()
@@ -867,29 +886,32 @@ class TestUUIDField(FieldValues):
assert field.to_internal_value(formatted_uuid_0) == uuid.UUID(int=0)
def test_formats(self):
- self._test_format('int', 0)
- self._test_format('hex_verbose', '00000000-0000-0000-0000-000000000000')
- self._test_format('urn', 'urn:uuid:00000000-0000-0000-0000-000000000000')
- self._test_format('hex', '0' * 32)
+ self._test_format("int", 0)
+ self._test_format("hex_verbose", "00000000-0000-0000-0000-000000000000")
+ self._test_format("urn", "urn:uuid:00000000-0000-0000-0000-000000000000")
+ self._test_format("hex", "0" * 32)
class TestIPAddressField(FieldValues):
"""
Valid and invalid values for `IPAddressField`
"""
+
valid_inputs = {
- '127.0.0.1': '127.0.0.1',
- '192.168.33.255': '192.168.33.255',
- '2001:0db8:85a3:0042:1000:8a2e:0370:7334': '2001:db8:85a3:42:1000:8a2e:370:7334',
- '2001:cdba:0:0:0:0:3257:9652': '2001:cdba::3257:9652',
- '2001:cdba::3257:9652': '2001:cdba::3257:9652'
+ "127.0.0.1": "127.0.0.1",
+ "192.168.33.255": "192.168.33.255",
+ "2001:0db8:85a3:0042:1000:8a2e:0370:7334": "2001:db8:85a3:42:1000:8a2e:370:7334",
+ "2001:cdba:0:0:0:0:3257:9652": "2001:cdba::3257:9652",
+ "2001:cdba::3257:9652": "2001:cdba::3257:9652",
}
invalid_inputs = {
- '127001': ['Enter a valid IPv4 or IPv6 address.'],
- '127.122.111.2231': ['Enter a valid IPv4 or IPv6 address.'],
- '2001:::9652': ['Enter a valid IPv4 or IPv6 address.'],
- '2001:0db8:85a3:0042:1000:8a2e:0370:73341': ['Enter a valid IPv4 or IPv6 address.'],
- 1000: ['Enter a valid IPv4 or IPv6 address.'],
+ "127001": ["Enter a valid IPv4 or IPv6 address."],
+ "127.122.111.2231": ["Enter a valid IPv4 or IPv6 address."],
+ "2001:::9652": ["Enter a valid IPv4 or IPv6 address."],
+ "2001:0db8:85a3:0042:1000:8a2e:0370:73341": [
+ "Enter a valid IPv4 or IPv6 address."
+ ],
+ 1000: ["Enter a valid IPv4 or IPv6 address."],
}
outputs = {}
field = serializers.IPAddressField()
@@ -899,33 +921,34 @@ class TestIPv4AddressField(FieldValues):
"""
Valid and invalid values for `IPAddressField`
"""
- valid_inputs = {
- '127.0.0.1': '127.0.0.1',
- '192.168.33.255': '192.168.33.255',
- }
+
+ valid_inputs = {"127.0.0.1": "127.0.0.1", "192.168.33.255": "192.168.33.255"}
invalid_inputs = {
- '127001': ['Enter a valid IPv4 address.'],
- '127.122.111.2231': ['Enter a valid IPv4 address.'],
+ "127001": ["Enter a valid IPv4 address."],
+ "127.122.111.2231": ["Enter a valid IPv4 address."],
}
outputs = {}
- field = serializers.IPAddressField(protocol='IPv4')
+ field = serializers.IPAddressField(protocol="IPv4")
class TestIPv6AddressField(FieldValues):
"""
Valid and invalid values for `IPAddressField`
"""
+
valid_inputs = {
- '2001:0db8:85a3:0042:1000:8a2e:0370:7334': '2001:db8:85a3:42:1000:8a2e:370:7334',
- '2001:cdba:0:0:0:0:3257:9652': '2001:cdba::3257:9652',
- '2001:cdba::3257:9652': '2001:cdba::3257:9652'
+ "2001:0db8:85a3:0042:1000:8a2e:0370:7334": "2001:db8:85a3:42:1000:8a2e:370:7334",
+ "2001:cdba:0:0:0:0:3257:9652": "2001:cdba::3257:9652",
+ "2001:cdba::3257:9652": "2001:cdba::3257:9652",
}
invalid_inputs = {
- '2001:::9652': ['Enter a valid IPv4 or IPv6 address.'],
- '2001:0db8:85a3:0042:1000:8a2e:0370:73341': ['Enter a valid IPv4 or IPv6 address.'],
+ "2001:::9652": ["Enter a valid IPv4 or IPv6 address."],
+ "2001:0db8:85a3:0042:1000:8a2e:0370:73341": [
+ "Enter a valid IPv4 or IPv6 address."
+ ],
}
outputs = {}
- field = serializers.IPAddressField(protocol='IPv6')
+ field = serializers.IPAddressField(protocol="IPv6")
class TestFilePathField(FieldValues):
@@ -933,47 +956,27 @@ class TestFilePathField(FieldValues):
Valid and invalid values for `FilePathField`
"""
- valid_inputs = {
- __file__: __file__,
- }
- invalid_inputs = {
- 'wrong_path': ['"wrong_path" is not a valid path choice.']
- }
- outputs = {
- }
- field = serializers.FilePathField(
- path=os.path.abspath(os.path.dirname(__file__))
- )
+ valid_inputs = {__file__: __file__}
+ invalid_inputs = {"wrong_path": ['"wrong_path" is not a valid path choice.']}
+ outputs = {}
+ field = serializers.FilePathField(path=os.path.abspath(os.path.dirname(__file__)))
# Number types...
+
class TestIntegerField(FieldValues):
"""
Valid and invalid values for `IntegerField`.
"""
- valid_inputs = {
- '1': 1,
- '0': 0,
- 1: 1,
- 0: 0,
- 1.0: 1,
- 0.0: 0,
- '1.0': 1
- }
+
+ valid_inputs = {"1": 1, "0": 0, 1: 1, 0: 0, 1.0: 1, 0.0: 0, "1.0": 1}
invalid_inputs = {
- 0.5: ['A valid integer is required.'],
- 'abc': ['A valid integer is required.'],
- '0.5': ['A valid integer is required.']
- }
- outputs = {
- '1': 1,
- '0': 0,
- 1: 1,
- 0: 0,
- 1.0: 1,
- 0.0: 0
+ 0.5: ["A valid integer is required."],
+ "abc": ["A valid integer is required."],
+ "0.5": ["A valid integer is required."],
}
+ outputs = {"1": 1, "0": 0, 1: 1, 0: 0, 1.0: 1, 0.0: 0}
field = serializers.IntegerField()
@@ -981,17 +984,13 @@ class TestMinMaxIntegerField(FieldValues):
"""
Valid and invalid values for `IntegerField` with min and max limits.
"""
- valid_inputs = {
- '1': 1,
- '3': 3,
- 1: 1,
- 3: 3,
- }
+
+ valid_inputs = {"1": 1, "3": 3, 1: 1, 3: 3}
invalid_inputs = {
- 0: ['Ensure this value is greater than or equal to 1.'],
- 4: ['Ensure this value is less than or equal to 3.'],
- '0': ['Ensure this value is greater than or equal to 1.'],
- '4': ['Ensure this value is less than or equal to 3.'],
+ 0: ["Ensure this value is greater than or equal to 1."],
+ 4: ["Ensure this value is less than or equal to 3."],
+ "0": ["Ensure this value is greater than or equal to 1."],
+ "4": ["Ensure this value is less than or equal to 3."],
}
outputs = {}
field = serializers.IntegerField(min_value=1, max_value=3)
@@ -1001,25 +1000,10 @@ class TestFloatField(FieldValues):
"""
Valid and invalid values for `FloatField`.
"""
- valid_inputs = {
- '1': 1.0,
- '0': 0.0,
- 1: 1.0,
- 0: 0.0,
- 1.0: 1.0,
- 0.0: 0.0,
- }
- invalid_inputs = {
- 'abc': ["A valid number is required."]
- }
- outputs = {
- '1': 1.0,
- '0': 0.0,
- 1: 1.0,
- 0: 0.0,
- 1.0: 1.0,
- 0.0: 0.0,
- }
+
+ valid_inputs = {"1": 1.0, "0": 0.0, 1: 1.0, 0: 0.0, 1.0: 1.0, 0.0: 0.0}
+ invalid_inputs = {"abc": ["A valid number is required."]}
+ outputs = {"1": 1.0, "0": 0.0, 1: 1.0, 0: 0.0, 1.0: 1.0, 0.0: 0.0}
field = serializers.FloatField()
@@ -1027,19 +1011,13 @@ class TestMinMaxFloatField(FieldValues):
"""
Valid and invalid values for `FloatField` with min and max limits.
"""
- valid_inputs = {
- '1': 1,
- '3': 3,
- 1: 1,
- 3: 3,
- 1.0: 1.0,
- 3.0: 3.0,
- }
+
+ valid_inputs = {"1": 1, "3": 3, 1: 1, 3: 3, 1.0: 1.0, 3.0: 3.0}
invalid_inputs = {
- 0.9: ['Ensure this value is greater than or equal to 1.'],
- 3.1: ['Ensure this value is less than or equal to 3.'],
- '0.0': ['Ensure this value is greater than or equal to 1.'],
- '3.1': ['Ensure this value is less than or equal to 3.'],
+ 0.9: ["Ensure this value is greater than or equal to 1."],
+ 3.1: ["Ensure this value is less than or equal to 3."],
+ "0.0": ["Ensure this value is greater than or equal to 1."],
+ "3.1": ["Ensure this value is less than or equal to 3."],
}
outputs = {}
field = serializers.FloatField(min_value=1, max_value=3)
@@ -1049,36 +1027,43 @@ class TestDecimalField(FieldValues):
"""
Valid and invalid values for `DecimalField`.
"""
+
valid_inputs = {
- '12.3': Decimal('12.3'),
- '0.1': Decimal('0.1'),
- 10: Decimal('10'),
- 0: Decimal('0'),
- 12.3: Decimal('12.3'),
- 0.1: Decimal('0.1'),
- '2E+1': Decimal('20'),
+ "12.3": Decimal("12.3"),
+ "0.1": Decimal("0.1"),
+ 10: Decimal("10"),
+ 0: Decimal("0"),
+ 12.3: Decimal("12.3"),
+ 0.1: Decimal("0.1"),
+ "2E+1": Decimal("20"),
}
invalid_inputs = (
- ('abc', ["A valid number is required."]),
- (Decimal('Nan'), ["A valid number is required."]),
- (Decimal('Inf'), ["A valid number is required."]),
- ('12.345', ["Ensure that there are no more than 3 digits in total."]),
+ ("abc", ["A valid number is required."]),
+ (Decimal("Nan"), ["A valid number is required."]),
+ (Decimal("Inf"), ["A valid number is required."]),
+ ("12.345", ["Ensure that there are no more than 3 digits in total."]),
(200000000000.0, ["Ensure that there are no more than 3 digits in total."]),
- ('0.01', ["Ensure that there are no more than 1 decimal places."]),
- (123, ["Ensure that there are no more than 2 digits before the decimal point."]),
- ('2E+2', ["Ensure that there are no more than 2 digits before the decimal point."])
+ ("0.01", ["Ensure that there are no more than 1 decimal places."]),
+ (
+ 123,
+ ["Ensure that there are no more than 2 digits before the decimal point."],
+ ),
+ (
+ "2E+2",
+ ["Ensure that there are no more than 2 digits before the decimal point."],
+ ),
)
outputs = {
- '1': '1.0',
- '0': '0.0',
- '1.09': '1.1',
- '0.04': '0.0',
- 1: '1.0',
- 0: '0.0',
- Decimal('1.0'): '1.0',
- Decimal('0.0'): '0.0',
- Decimal('1.09'): '1.1',
- Decimal('0.04'): '0.0'
+ "1": "1.0",
+ "0": "0.0",
+ "1.09": "1.1",
+ "0.04": "0.0",
+ 1: "1.0",
+ 0: "0.0",
+ Decimal("1.0"): "1.0",
+ Decimal("0.0"): "0.0",
+ Decimal("1.09"): "1.1",
+ Decimal("0.04"): "0.0",
}
field = serializers.DecimalField(max_digits=3, decimal_places=1)
@@ -1087,29 +1072,23 @@ class TestMinMaxDecimalField(FieldValues):
"""
Valid and invalid values for `DecimalField` with min and max limits.
"""
- valid_inputs = {
- '10.0': Decimal('10.0'),
- '20.0': Decimal('20.0'),
- }
+
+ valid_inputs = {"10.0": Decimal("10.0"), "20.0": Decimal("20.0")}
invalid_inputs = {
- '9.9': ['Ensure this value is greater than or equal to 10.'],
- '20.1': ['Ensure this value is less than or equal to 20.'],
+ "9.9": ["Ensure this value is greater than or equal to 10."],
+ "20.1": ["Ensure this value is less than or equal to 20."],
}
outputs = {}
field = serializers.DecimalField(
- max_digits=3, decimal_places=1,
- min_value=10, max_value=20
+ max_digits=3, decimal_places=1, min_value=10, max_value=20
)
class TestNoMaxDigitsDecimalField(FieldValues):
field = serializers.DecimalField(
- max_value=100, min_value=0,
- decimal_places=2, max_digits=None
+ max_value=100, min_value=0, decimal_places=2, max_digits=None
)
- valid_inputs = {
- '10': Decimal('10.00')
- }
+ valid_inputs = {"10": Decimal("10.00")}
invalid_inputs = {}
outputs = {}
@@ -1118,36 +1097,38 @@ class TestNoStringCoercionDecimalField(FieldValues):
"""
Output values for `DecimalField` with `coerce_to_string=False`.
"""
+
valid_inputs = {}
invalid_inputs = {}
outputs = {
- 1.09: Decimal('1.1'),
- 0.04: Decimal('0.0'),
- '1.09': Decimal('1.1'),
- '0.04': Decimal('0.0'),
- Decimal('1.09'): Decimal('1.1'),
- Decimal('0.04'): Decimal('0.0'),
+ 1.09: Decimal("1.1"),
+ 0.04: Decimal("0.0"),
+ "1.09": Decimal("1.1"),
+ "0.04": Decimal("0.0"),
+ Decimal("1.09"): Decimal("1.1"),
+ Decimal("0.04"): Decimal("0.0"),
}
field = serializers.DecimalField(
- max_digits=3, decimal_places=1,
- coerce_to_string=False
+ max_digits=3, decimal_places=1, coerce_to_string=False
)
class TestLocalizedDecimalField(TestCase):
- @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
+ @override_settings(USE_L10N=True, LANGUAGE_CODE="pl")
def test_to_internal_value(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
- assert field.to_internal_value('1,1') == Decimal('1.1')
+ assert field.to_internal_value("1,1") == Decimal("1.1")
- @override_settings(USE_L10N=True, LANGUAGE_CODE='pl')
+ @override_settings(USE_L10N=True, LANGUAGE_CODE="pl")
def test_to_representation(self):
field = serializers.DecimalField(max_digits=2, decimal_places=1, localize=True)
- assert field.to_representation(Decimal('1.1')) == '1,1'
+ assert field.to_representation(Decimal("1.1")) == "1,1"
def test_localize_forces_coerce_to_string(self):
- field = serializers.DecimalField(max_digits=2, decimal_places=1, coerce_to_string=False, localize=True)
- assert isinstance(field.to_representation(Decimal('1.1')), six.string_types)
+ field = serializers.DecimalField(
+ max_digits=2, decimal_places=1, coerce_to_string=False, localize=True
+ )
+ assert isinstance(field.to_representation(Decimal("1.1")), six.string_types)
class TestQuantizedValueForDecimal(TestCase):
@@ -1159,44 +1140,44 @@ class TestQuantizedValueForDecimal(TestCase):
def test_string_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2)
- value = field.to_internal_value('12').as_tuple()
+ value = field.to_internal_value("12").as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
def test_part_precision_string_quantized_value_for_decimal(self):
field = serializers.DecimalField(max_digits=4, decimal_places=2)
- value = field.to_internal_value('12.0').as_tuple()
+ value = field.to_internal_value("12.0").as_tuple()
expected_digit_tuple = (0, (1, 2, 0, 0), -2)
assert value == expected_digit_tuple
class TestNoDecimalPlaces(FieldValues):
- valid_inputs = {
- '0.12345': Decimal('0.12345'),
- }
+ valid_inputs = {"0.12345": Decimal("0.12345")}
invalid_inputs = {
- '0.1234567': ['Ensure that there are no more than 6 digits in total.']
- }
- outputs = {
- '1.2345': '1.2345',
- '0': '0',
- '1.1': '1.1',
+ "0.1234567": ["Ensure that there are no more than 6 digits in total."]
}
+ outputs = {"1.2345": "1.2345", "0": "0", "1.1": "1.1"}
field = serializers.DecimalField(max_digits=6, decimal_places=None)
class TestRoundingDecimalField(TestCase):
def test_valid_rounding(self):
- field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_UP)
- assert field.to_representation(Decimal('1.234')) == '1.24'
+ field = serializers.DecimalField(
+ max_digits=4, decimal_places=2, rounding=ROUND_UP
+ )
+ assert field.to_representation(Decimal("1.234")) == "1.24"
- field = serializers.DecimalField(max_digits=4, decimal_places=2, rounding=ROUND_DOWN)
- assert field.to_representation(Decimal('1.234')) == '1.23'
+ field = serializers.DecimalField(
+ max_digits=4, decimal_places=2, rounding=ROUND_DOWN
+ )
+ assert field.to_representation(Decimal("1.234")) == "1.23"
def test_invalid_rounding(self):
with pytest.raises(AssertionError) as excinfo:
- serializers.DecimalField(max_digits=1, decimal_places=1, rounding='ROUND_UNKNOWN')
- assert 'Invalid rounding option' in str(excinfo.value)
+ serializers.DecimalField(
+ max_digits=1, decimal_places=1, rounding="ROUND_UNKNOWN"
+ )
+ assert "Invalid rounding option" in str(excinfo.value)
# Date & time serializers...
@@ -1204,23 +1185,30 @@ class TestDateField(FieldValues):
"""
Valid and invalid values for `DateField`.
"""
+
valid_inputs = {
- '2001-01-01': datetime.date(2001, 1, 1),
+ "2001-01-01": datetime.date(2001, 1, 1),
datetime.date(2001, 1, 1): datetime.date(2001, 1, 1),
}
invalid_inputs = {
- 'abc': ['Date has wrong format. Use one of these formats instead: YYYY-MM-DD.'],
- '2001-99-99': ['Date has wrong format. Use one of these formats instead: YYYY-MM-DD.'],
- '2001-01': ['Date has wrong format. Use one of these formats instead: YYYY-MM-DD.'],
- '2001': ['Date has wrong format. Use one of these formats instead: YYYY-MM-DD.'],
- datetime.datetime(2001, 1, 1, 12, 00): ['Expected a date but got a datetime.'],
+ "abc": ["Date has wrong format. Use one of these formats instead: YYYY-MM-DD."],
+ "2001-99-99": [
+ "Date has wrong format. Use one of these formats instead: YYYY-MM-DD."
+ ],
+ "2001-01": [
+ "Date has wrong format. Use one of these formats instead: YYYY-MM-DD."
+ ],
+ "2001": [
+ "Date has wrong format. Use one of these formats instead: YYYY-MM-DD."
+ ],
+ datetime.datetime(2001, 1, 1, 12, 00): ["Expected a date but got a datetime."],
}
outputs = {
- datetime.date(2001, 1, 1): '2001-01-01',
- '2001-01-01': '2001-01-01',
- six.text_type('2016-01-10'): '2016-01-10',
+ datetime.date(2001, 1, 1): "2001-01-01",
+ "2001-01-01": "2001-01-01",
+ six.text_type("2016-01-10"): "2016-01-10",
None: None,
- '': None,
+ "": None,
}
field = serializers.DateField()
@@ -1229,37 +1217,36 @@ class TestCustomInputFormatDateField(FieldValues):
"""
Valid and invalid values for `DateField` with a custom input format.
"""
- valid_inputs = {
- '1 Jan 2001': datetime.date(2001, 1, 1),
- }
+
+ valid_inputs = {"1 Jan 2001": datetime.date(2001, 1, 1)}
invalid_inputs = {
- '2001-01-01': ['Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY.']
+ "2001-01-01": [
+ "Date has wrong format. Use one of these formats instead: DD [Jan-Dec] YYYY."
+ ]
}
outputs = {}
- field = serializers.DateField(input_formats=['%d %b %Y'])
+ field = serializers.DateField(input_formats=["%d %b %Y"])
class TestCustomOutputFormatDateField(FieldValues):
"""
Values for `DateField` with a custom output format.
"""
+
valid_inputs = {}
invalid_inputs = {}
- outputs = {
- datetime.date(2001, 1, 1): '01 Jan 2001'
- }
- field = serializers.DateField(format='%d %b %Y')
+ outputs = {datetime.date(2001, 1, 1): "01 Jan 2001"}
+ field = serializers.DateField(format="%d %b %Y")
class TestNoOutputFormatDateField(FieldValues):
"""
Values for `DateField` with no output format.
"""
+
valid_inputs = {}
invalid_inputs = {}
- outputs = {
- datetime.date(2001, 1, 1): datetime.date(2001, 1, 1)
- }
+ outputs = {datetime.date(2001, 1, 1): datetime.date(2001, 1, 1)}
field = serializers.DateField(format=None)
@@ -1267,27 +1254,38 @@ class TestDateTimeField(FieldValues):
"""
Valid and invalid values for `DateTimeField`.
"""
+
valid_inputs = {
- '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
- '2001-01-01T13:00': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
- '2001-01-01T13:00Z': datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
- datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
- datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
+ "2001-01-01 13:00": datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
+ "2001-01-01T13:00": datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
+ "2001-01-01T13:00Z": datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc),
+ datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(
+ 2001, 1, 1, 13, 00, tzinfo=utc
+ ),
+ datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): datetime.datetime(
+ 2001, 1, 1, 13, 00, tzinfo=utc
+ ),
}
invalid_inputs = {
- 'abc': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
- '2001-99-99T99:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
- '2018-08-16 22:00-24:00': ['Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z].'],
- datetime.date(2001, 1, 1): ['Expected a datetime but got a date.'],
- '9999-12-31T21:59:59.99990-03:00': ['Datetime value out of range.'],
+ "abc": [
+ "Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
+ ],
+ "2001-99-99T99:00": [
+ "Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
+ ],
+ "2018-08-16 22:00-24:00": [
+ "Datetime has wrong format. Use one of these formats instead: YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]."
+ ],
+ datetime.date(2001, 1, 1): ["Expected a datetime but got a date."],
+ "9999-12-31T21:59:59.99990-03:00": ["Datetime value out of range."],
}
outputs = {
- datetime.datetime(2001, 1, 1, 13, 00): '2001-01-01T13:00:00Z',
- datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): '2001-01-01T13:00:00Z',
- '2001-01-01T00:00:00': '2001-01-01T00:00:00',
- six.text_type('2016-01-10T00:00:00'): '2016-01-10T00:00:00',
+ datetime.datetime(2001, 1, 1, 13, 00): "2001-01-01T13:00:00Z",
+ datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): "2001-01-01T13:00:00Z",
+ "2001-01-01T00:00:00": "2001-01-01T00:00:00",
+ six.text_type("2016-01-10T00:00:00"): "2016-01-10T00:00:00",
None: None,
- '': None,
+ "": None,
}
field = serializers.DateTimeField(default_timezone=utc)
@@ -1296,36 +1294,41 @@ class TestCustomInputFormatDateTimeField(FieldValues):
"""
Valid and invalid values for `DateTimeField` with a custom input format.
"""
+
valid_inputs = {
- '1:35pm, 1 Jan 2001': datetime.datetime(2001, 1, 1, 13, 35, tzinfo=utc),
+ "1:35pm, 1 Jan 2001": datetime.datetime(2001, 1, 1, 13, 35, tzinfo=utc)
}
invalid_inputs = {
- '2001-01-01T20:50': ['Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY.']
+ "2001-01-01T20:50": [
+ "Datetime has wrong format. Use one of these formats instead: hh:mm[AM|PM], DD [Jan-Dec] YYYY."
+ ]
}
outputs = {}
- field = serializers.DateTimeField(default_timezone=utc, input_formats=['%I:%M%p, %d %b %Y'])
+ field = serializers.DateTimeField(
+ default_timezone=utc, input_formats=["%I:%M%p, %d %b %Y"]
+ )
class TestCustomOutputFormatDateTimeField(FieldValues):
"""
Values for `DateTimeField` with a custom output format.
"""
+
valid_inputs = {}
invalid_inputs = {}
- outputs = {
- datetime.datetime(2001, 1, 1, 13, 00): '01:00PM, 01 Jan 2001',
- }
- field = serializers.DateTimeField(format='%I:%M%p, %d %b %Y')
+ outputs = {datetime.datetime(2001, 1, 1, 13, 00): "01:00PM, 01 Jan 2001"}
+ field = serializers.DateTimeField(format="%I:%M%p, %d %b %Y")
class TestNoOutputFormatDateTimeField(FieldValues):
"""
Values for `DateTimeField` with no output format.
"""
+
valid_inputs = {}
invalid_inputs = {}
outputs = {
- datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00),
+ datetime.datetime(2001, 1, 1, 13, 00): datetime.datetime(2001, 1, 1, 13, 00)
}
field = serializers.DateTimeField(format=None)
@@ -1334,14 +1337,17 @@ class TestNaiveDateTimeField(FieldValues):
"""
Valid and invalid values for `DateTimeField` with naive datetimes.
"""
+
valid_inputs = {
- datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): datetime.datetime(2001, 1, 1, 13, 00),
- '2001-01-01 13:00': datetime.datetime(2001, 1, 1, 13, 00),
+ datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): datetime.datetime(
+ 2001, 1, 1, 13, 00
+ ),
+ "2001-01-01 13:00": datetime.datetime(2001, 1, 1, 13, 00),
}
invalid_inputs = {}
outputs = {
- datetime.datetime(2001, 1, 1, 13, 00): '2001-01-01T13:00:00',
- datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): '2001-01-01T13:00:00',
+ datetime.datetime(2001, 1, 1, 13, 00): "2001-01-01T13:00:00",
+ datetime.datetime(2001, 1, 1, 13, 00, tzinfo=utc): "2001-01-01T13:00:00",
}
field = serializers.DateTimeField(default_timezone=None)
@@ -1350,25 +1356,34 @@ class TestTZWithDateTimeField(FieldValues):
"""
Valid and invalid values for `DateTimeField` when not using UTC as the timezone.
"""
+
@classmethod
def setup_class(cls):
# use class setup method, as class-level attribute will still be evaluated even if test is skipped
- kolkata = pytz.timezone('Asia/Kolkata')
+ kolkata = pytz.timezone("Asia/Kolkata")
cls.valid_inputs = {
- '2016-12-19T10:00:00': kolkata.localize(datetime.datetime(2016, 12, 19, 10)),
- '2016-12-19T10:00:00+05:30': kolkata.localize(datetime.datetime(2016, 12, 19, 10)),
- datetime.datetime(2016, 12, 19, 10): kolkata.localize(datetime.datetime(2016, 12, 19, 10)),
+ "2016-12-19T10:00:00": kolkata.localize(
+ datetime.datetime(2016, 12, 19, 10)
+ ),
+ "2016-12-19T10:00:00+05:30": kolkata.localize(
+ datetime.datetime(2016, 12, 19, 10)
+ ),
+ datetime.datetime(2016, 12, 19, 10): kolkata.localize(
+ datetime.datetime(2016, 12, 19, 10)
+ ),
}
cls.invalid_inputs = {}
cls.outputs = {
- datetime.datetime(2016, 12, 19, 10): '2016-12-19T10:00:00+05:30',
- datetime.datetime(2016, 12, 19, 4, 30, tzinfo=utc): '2016-12-19T10:00:00+05:30',
+ datetime.datetime(2016, 12, 19, 10): "2016-12-19T10:00:00+05:30",
+ datetime.datetime(
+ 2016, 12, 19, 4, 30, tzinfo=utc
+ ): "2016-12-19T10:00:00+05:30",
}
cls.field = serializers.DateTimeField(default_timezone=kolkata)
-@override_settings(TIME_ZONE='UTC', USE_TZ=True)
+@override_settings(TIME_ZONE="UTC", USE_TZ=True)
class TestDefaultTZDateTimeField(TestCase):
"""
Test the current/default timezone handling in `DateTimeField`.
@@ -1377,7 +1392,7 @@ class TestDefaultTZDateTimeField(TestCase):
@classmethod
def setup_class(cls):
cls.field = serializers.DateTimeField()
- cls.kolkata = pytz.timezone('Asia/Kolkata')
+ cls.kolkata = pytz.timezone("Asia/Kolkata")
def test_default_timezone(self):
assert self.field.default_timezone() == utc
@@ -1390,23 +1405,26 @@ class TestDefaultTZDateTimeField(TestCase):
assert self.field.default_timezone() == utc
-@pytest.mark.skipif(pytz is None, reason='pytz not installed')
-@override_settings(TIME_ZONE='UTC', USE_TZ=True)
+@pytest.mark.skipif(pytz is None, reason="pytz not installed")
+@override_settings(TIME_ZONE="UTC", USE_TZ=True)
class TestCustomTimezoneForDateTimeField(TestCase):
-
@classmethod
def setup_class(cls):
- cls.kolkata = pytz.timezone('Asia/Kolkata')
- cls.date_format = '%d/%m/%Y %H:%M'
+ cls.kolkata = pytz.timezone("Asia/Kolkata")
+ cls.date_format = "%d/%m/%Y %H:%M"
def test_should_render_date_time_in_default_timezone(self):
- field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
+ field = serializers.DateTimeField(
+ default_timezone=self.kolkata, format=self.date_format
+ )
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
with override(self.kolkata):
rendered_date = field.to_representation(dt)
- rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(self.date_format)
+ rendered_date_in_timezone = dt.astimezone(self.kolkata).strftime(
+ self.date_format
+ )
assert rendered_date == rendered_date_in_timezone
@@ -1417,10 +1435,15 @@ class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
Timezone America/New_York has DST shift from 2017-03-12T02:00:00 to 2017-03-12T03:00:00 and
from 2017-11-05T02:00:00 to 2017-11-05T01:00:00 in 2017.
"""
+
valid_inputs = {}
invalid_inputs = {
- '2017-03-12T02:30:00': ['Invalid datetime for the timezone "America/New_York".'],
- '2017-11-05T01:30:00': ['Invalid datetime for the timezone "America/New_York".']
+ "2017-03-12T02:30:00": [
+ 'Invalid datetime for the timezone "America/New_York".'
+ ],
+ "2017-11-05T01:30:00": [
+ 'Invalid datetime for the timezone "America/New_York".'
+ ],
}
outputs = {}
@@ -1430,7 +1453,7 @@ class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
raise pytz.InvalidTimeError()
def __str__(self):
- return 'America/New_York'
+ return "America/New_York"
field = serializers.DateTimeField(default_timezone=MockTimezone())
@@ -1439,20 +1462,25 @@ class TestTimeField(FieldValues):
"""
Valid and invalid values for `TimeField`.
"""
+
valid_inputs = {
- '13:00': datetime.time(13, 00),
+ "13:00": datetime.time(13, 00),
datetime.time(13, 00): datetime.time(13, 00),
}
invalid_inputs = {
- 'abc': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
- '99:99': ['Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]].'],
+ "abc": [
+ "Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]."
+ ],
+ "99:99": [
+ "Time has wrong format. Use one of these formats instead: hh:mm[:ss[.uuuuuu]]."
+ ],
}
outputs = {
- datetime.time(13, 0): '13:00:00',
- datetime.time(0, 0): '00:00:00',
- '00:00:00': '00:00:00',
+ datetime.time(13, 0): "13:00:00",
+ datetime.time(0, 0): "00:00:00",
+ "00:00:00": "00:00:00",
None: None,
- '': None,
+ "": None,
}
field = serializers.TimeField()
@@ -1461,37 +1489,36 @@ class TestCustomInputFormatTimeField(FieldValues):
"""
Valid and invalid values for `TimeField` with a custom input format.
"""
- valid_inputs = {
- '1:00pm': datetime.time(13, 00),
- }
+
+ valid_inputs = {"1:00pm": datetime.time(13, 00)}
invalid_inputs = {
- '13:00': ['Time has wrong format. Use one of these formats instead: hh:mm[AM|PM].'],
+ "13:00": [
+ "Time has wrong format. Use one of these formats instead: hh:mm[AM|PM]."
+ ]
}
outputs = {}
- field = serializers.TimeField(input_formats=['%I:%M%p'])
+ field = serializers.TimeField(input_formats=["%I:%M%p"])
class TestCustomOutputFormatTimeField(FieldValues):
"""
Values for `TimeField` with a custom output format.
"""
+
valid_inputs = {}
invalid_inputs = {}
- outputs = {
- datetime.time(13, 00): '01:00PM'
- }
- field = serializers.TimeField(format='%I:%M%p')
+ outputs = {datetime.time(13, 00): "01:00PM"}
+ field = serializers.TimeField(format="%I:%M%p")
class TestNoOutputFormatTimeField(FieldValues):
"""
Values for `TimeField` with a no output format.
"""
+
valid_inputs = {}
invalid_inputs = {}
- outputs = {
- datetime.time(13, 00): datetime.time(13, 00)
- }
+ outputs = {datetime.time(13, 00): datetime.time(13, 00)}
field = serializers.TimeField(format=None)
@@ -1499,64 +1526,74 @@ class TestMinMaxDurationField(FieldValues):
"""
Valid and invalid values for `DurationField` with min and max limits.
"""
+
valid_inputs = {
- '3 08:32:01.000123': datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
+ "3 08:32:01.000123": datetime.timedelta(
+ days=3, hours=8, minutes=32, seconds=1, microseconds=123
+ ),
86401: datetime.timedelta(days=1, seconds=1),
}
invalid_inputs = {
- 3600: ['Ensure this value is greater than or equal to 1 day, 0:00:00.'],
- '4 08:32:01.000123': ['Ensure this value is less than or equal to 4 days, 0:00:00.'],
- '3600': ['Ensure this value is greater than or equal to 1 day, 0:00:00.'],
+ 3600: ["Ensure this value is greater than or equal to 1 day, 0:00:00."],
+ "4 08:32:01.000123": [
+ "Ensure this value is less than or equal to 4 days, 0:00:00."
+ ],
+ "3600": ["Ensure this value is greater than or equal to 1 day, 0:00:00."],
}
outputs = {}
- field = serializers.DurationField(min_value=datetime.timedelta(days=1), max_value=datetime.timedelta(days=4))
+ field = serializers.DurationField(
+ min_value=datetime.timedelta(days=1), max_value=datetime.timedelta(days=4)
+ )
class TestDurationField(FieldValues):
"""
Valid and invalid values for `DurationField`.
"""
+
valid_inputs = {
- '13': datetime.timedelta(seconds=13),
- '3 08:32:01.000123': datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
- '08:01': datetime.timedelta(minutes=8, seconds=1),
- datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
+ "13": datetime.timedelta(seconds=13),
+ "3 08:32:01.000123": datetime.timedelta(
+ days=3, hours=8, minutes=32, seconds=1, microseconds=123
+ ),
+ "08:01": datetime.timedelta(minutes=8, seconds=1),
+ datetime.timedelta(
+ days=3, hours=8, minutes=32, seconds=1, microseconds=123
+ ): datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
3600: datetime.timedelta(hours=1),
}
invalid_inputs = {
- 'abc': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
- '3 08:32 01.123': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
+ "abc": [
+ "Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu]."
+ ],
+ "3 08:32 01.123": [
+ "Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu]."
+ ],
}
outputs = {
- datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): '3 08:32:01.000123',
+ datetime.timedelta(
+ days=3, hours=8, minutes=32, seconds=1, microseconds=123
+ ): "3 08:32:01.000123"
}
field = serializers.DurationField()
# Choice types...
+
class TestChoiceField(FieldValues):
"""
Valid and invalid values for `ChoiceField`.
"""
- valid_inputs = {
- 'poor': 'poor',
- 'medium': 'medium',
- 'good': 'good',
- }
- invalid_inputs = {
- 'amazing': ['"amazing" is not a valid choice.']
- }
- outputs = {
- 'good': 'good',
- '': '',
- 'amazing': 'amazing',
- }
+
+ valid_inputs = {"poor": "poor", "medium": "medium", "good": "good"}
+ invalid_inputs = {"amazing": ['"amazing" is not a valid choice.']}
+ outputs = {"good": "good", "": "", "amazing": "amazing"}
field = serializers.ChoiceField(
choices=[
- ('poor', 'Poor quality'),
- ('medium', 'Medium quality'),
- ('good', 'Good quality'),
+ ("poor", "Poor quality"),
+ ("medium", "Medium quality"),
+ ("good", "Good quality"),
]
)
@@ -1567,26 +1604,21 @@ class TestChoiceField(FieldValues):
field = serializers.ChoiceField(
allow_blank=True,
choices=[
- ('poor', 'Poor quality'),
- ('medium', 'Medium quality'),
- ('good', 'Good quality'),
- ]
+ ("poor", "Poor quality"),
+ ("medium", "Medium quality"),
+ ("good", "Good quality"),
+ ],
)
- output = field.run_validation('')
- assert output == ''
+ output = field.run_validation("")
+ assert output == ""
def test_allow_null(self):
"""
If `allow_null=True` then '' on HTML forms is treated as None.
"""
- field = serializers.ChoiceField(
- allow_null=True,
- choices=[
- 1, 2, 3
- ]
- )
- field.field_name = 'example'
- value = field.get_value(QueryDict('example='))
+ field = serializers.ChoiceField(allow_null=True, choices=[1, 2, 3])
+ field.field_name = "example"
+ value = field.get_value(QueryDict("example="))
assert value is None
output = field.run_validation(None)
assert output is None
@@ -1597,35 +1629,30 @@ class TestChoiceField(FieldValues):
"""
field = serializers.ChoiceField(
choices=[
- ('Numbers', ['integer', 'float']),
- ('Strings', ['text', 'email', 'url']),
- 'boolean'
+ ("Numbers", ["integer", "float"]),
+ ("Strings", ["text", "email", "url"]),
+ "boolean",
]
)
items = list(field.iter_options())
assert items[0].start_option_group
- assert items[0].label == 'Numbers'
- assert items[1].value == 'integer'
- assert items[2].value == 'float'
+ assert items[0].label == "Numbers"
+ assert items[1].value == "integer"
+ assert items[2].value == "float"
assert items[3].end_option_group
assert items[4].start_option_group
- assert items[4].label == 'Strings'
- assert items[5].value == 'text'
- assert items[6].value == 'email'
- assert items[7].value == 'url'
+ assert items[4].label == "Strings"
+ assert items[5].value == "text"
+ assert items[6].value == "email"
+ assert items[7].value == "url"
assert items[8].end_option_group
- assert items[9].value == 'boolean'
+ assert items[9].value == "boolean"
def test_edit_choices(self):
- field = serializers.ChoiceField(
- allow_null=True,
- choices=[
- 1, 2,
- ]
- )
+ field = serializers.ChoiceField(allow_null=True, choices=[1, 2])
field.choices = [1]
assert field.run_validation(1) is 1
with pytest.raises(serializers.ValidationError) as exc_info:
@@ -1638,24 +1665,15 @@ class TestChoiceFieldWithType(FieldValues):
Valid and invalid values for a `Choice` field that uses an integer type,
instead of a char type.
"""
- valid_inputs = {
- '1': 1,
- 3: 3,
- }
+
+ valid_inputs = {"1": 1, 3: 3}
invalid_inputs = {
5: ['"5" is not a valid choice.'],
- 'abc': ['"abc" is not a valid choice.']
- }
- outputs = {
- '1': 1,
- 1: 1
+ "abc": ['"abc" is not a valid choice.'],
}
+ outputs = {"1": 1, 1: 1}
field = serializers.ChoiceField(
- choices=[
- (1, 'Poor quality'),
- (2, 'Medium quality'),
- (3, 'Good quality'),
- ]
+ choices=[(1, "Poor quality"), (2, "Medium quality"), (3, "Good quality")]
)
@@ -1664,18 +1682,11 @@ class TestChoiceFieldWithListChoices(FieldValues):
Valid and invalid values for a `Choice` field that uses a flat list for the
choices, rather than a list of pairs of (`value`, `description`).
"""
- valid_inputs = {
- 'poor': 'poor',
- 'medium': 'medium',
- 'good': 'good',
- }
- invalid_inputs = {
- 'awful': ['"awful" is not a valid choice.']
- }
- outputs = {
- 'good': 'good'
- }
- field = serializers.ChoiceField(choices=('poor', 'medium', 'good'))
+
+ valid_inputs = {"poor": "poor", "medium": "medium", "good": "good"}
+ invalid_inputs = {"awful": ['"awful" is not a valid choice.']}
+ outputs = {"good": "good"}
+ field = serializers.ChoiceField(choices=("poor", "medium", "good"))
class TestChoiceFieldWithGroupedChoices(FieldValues):
@@ -1683,27 +1694,14 @@ class TestChoiceFieldWithGroupedChoices(FieldValues):
Valid and invalid values for a `Choice` field that uses a grouped list for the
choices, rather than a list of pairs of (`value`, `description`).
"""
- valid_inputs = {
- 'poor': 'poor',
- 'medium': 'medium',
- 'good': 'good',
- }
- invalid_inputs = {
- 'awful': ['"awful" is not a valid choice.']
- }
- outputs = {
- 'good': 'good'
- }
+
+ valid_inputs = {"poor": "poor", "medium": "medium", "good": "good"}
+ invalid_inputs = {"awful": ['"awful" is not a valid choice.']}
+ outputs = {"good": "good"}
field = serializers.ChoiceField(
choices=[
- (
- 'Category',
- (
- ('poor', 'Poor quality'),
- ('medium', 'Medium quality'),
- ),
- ),
- ('good', 'Good quality'),
+ ("Category", (("poor", "Poor quality"), ("medium", "Medium quality"))),
+ ("good", "Good quality"),
]
)
@@ -1713,27 +1711,15 @@ class TestChoiceFieldWithMixedChoices(FieldValues):
Valid and invalid values for a `Choice` field that uses a single paired or
grouped.
"""
- valid_inputs = {
- 'poor': 'poor',
- 'medium': 'medium',
- 'good': 'good',
- }
- invalid_inputs = {
- 'awful': ['"awful" is not a valid choice.']
- }
- outputs = {
- 'good': 'good'
- }
+
+ valid_inputs = {"poor": "poor", "medium": "medium", "good": "good"}
+ invalid_inputs = {"awful": ['"awful" is not a valid choice.']}
+ outputs = {"good": "good"}
field = serializers.ChoiceField(
choices=[
- (
- 'Category',
- (
- ('poor', 'Poor quality'),
- ),
- ),
- 'medium',
- ('good', 'Good quality'),
+ ("Category", (("poor", "Poor quality"),)),
+ "medium",
+ ("good", "Good quality"),
]
)
@@ -1742,28 +1728,23 @@ class TestMultipleChoiceField(FieldValues):
"""
Valid and invalid values for `MultipleChoiceField`.
"""
+
valid_inputs = {
(): set(),
- ('aircon',): {'aircon'},
- ('aircon', 'manual'): {'aircon', 'manual'},
+ ("aircon",): {"aircon"},
+ ("aircon", "manual"): {"aircon", "manual"},
}
invalid_inputs = {
- 'abc': ['Expected a list of items but got type "str".'],
- ('aircon', 'incorrect'): ['"incorrect" is not a valid choice.']
+ "abc": ['Expected a list of items but got type "str".'],
+ ("aircon", "incorrect"): ['"incorrect" is not a valid choice.'],
}
- outputs = [
- (['aircon', 'manual', 'incorrect'], {'aircon', 'manual', 'incorrect'})
- ]
+ outputs = [(["aircon", "manual", "incorrect"], {"aircon", "manual", "incorrect"})]
field = serializers.MultipleChoiceField(
- choices=[
- ('aircon', 'AirCon'),
- ('manual', 'Manual drive'),
- ('diesel', 'Diesel'),
- ]
+ choices=[("aircon", "AirCon"), ("manual", "Manual drive"), ("diesel", "Diesel")]
)
def test_against_partial_and_full_updates(self):
- field = serializers.MultipleChoiceField(choices=(('a', 'a'), ('b', 'b')))
+ field = serializers.MultipleChoiceField(choices=(("a", "a"), ("b", "b")))
field.partial = False
assert field.get_value(QueryDict({})) == []
field.partial = True
@@ -1774,37 +1755,35 @@ class TestEmptyMultipleChoiceField(FieldValues):
"""
Invalid values for `MultipleChoiceField(allow_empty=False)`.
"""
- valid_inputs = {
- }
- invalid_inputs = (
- ([], ['This selection may not be empty.']),
- )
- outputs = [
- ]
+
+ valid_inputs = {}
+ invalid_inputs = (([], ["This selection may not be empty."]),)
+ outputs = []
field = serializers.MultipleChoiceField(
choices=[
- ('consistency', 'Consistency'),
- ('availability', 'Availability'),
- ('partition', 'Partition tolerance'),
+ ("consistency", "Consistency"),
+ ("availability", "Availability"),
+ ("partition", "Partition tolerance"),
],
- allow_empty=False
+ allow_empty=False,
)
# File serializers...
+
class MockFile:
- def __init__(self, name='', size=0, url=''):
+ def __init__(self, name="", size=0, url=""):
self.name = name
self.size = size
self.url = url
def __eq__(self, other):
return (
- isinstance(other, MockFile) and
- self.name == other.name and
- self.size == other.size and
- self.url == other.url
+ isinstance(other, MockFile)
+ and self.name == other.name
+ and self.size == other.size
+ and self.url == other.url
)
@@ -1812,18 +1791,25 @@ class TestFileField(FieldValues):
"""
Values for `FileField`.
"""
+
valid_inputs = [
- (MockFile(name='example', size=10), MockFile(name='example', size=10))
+ (MockFile(name="example", size=10), MockFile(name="example", size=10))
]
invalid_inputs = [
- ('invalid', ['The submitted data was not a file. Check the encoding type on the form.']),
- (MockFile(name='example.txt', size=0), ['The submitted file is empty.']),
- (MockFile(name='', size=10), ['No filename could be determined.']),
- (MockFile(name='x' * 100, size=10), ['Ensure this filename has at most 10 characters (it has 100).'])
+ (
+ "invalid",
+ ["The submitted data was not a file. Check the encoding type on the form."],
+ ),
+ (MockFile(name="example.txt", size=0), ["The submitted file is empty."]),
+ (MockFile(name="", size=10), ["No filename could be determined."]),
+ (
+ MockFile(name="x" * 100, size=10),
+ ["Ensure this filename has at most 10 characters (it has 100)."],
+ ),
]
outputs = [
- (MockFile(name='example.txt', url='/example.txt'), '/example.txt'),
- ('', None)
+ (MockFile(name="example.txt", url="/example.txt"), "/example.txt"),
+ ("", None),
]
field = serializers.FileField(max_length=10)
@@ -1832,17 +1818,18 @@ class TestFieldFieldWithName(FieldValues):
"""
Values for `FileField` with a filename output instead of URLs.
"""
+
valid_inputs = {}
invalid_inputs = {}
- outputs = [
- (MockFile(name='example.txt', url='/example.txt'), 'example.txt')
- ]
+ outputs = [(MockFile(name="example.txt", url="/example.txt"), "example.txt")]
field = serializers.FileField(use_url=False)
def ext_validator(value):
- if not value.name.endswith('.png'):
- raise serializers.ValidationError('File extension is not allowed. Allowed extensions is png.')
+ if not value.name.endswith(".png"):
+ raise serializers.ValidationError(
+ "File extension is not allowed. Allowed extensions is png."
+ )
# Stub out mock Django `forms.ImageField` class so we don't *actually*
@@ -1856,8 +1843,8 @@ class PassImageValidation(DjangoImageField):
class FailImageValidation(PassImageValidation):
def to_python(self, value):
- if value.name == 'badimage.png':
- raise serializers.ValidationError(self.error_messages['invalid_image'])
+ if value.name == "badimage.png":
+ raise serializers.ValidationError(self.error_messages["invalid_image"])
return value
@@ -1865,10 +1852,19 @@ class TestInvalidImageField(FieldValues):
"""
Values for an invalid `ImageField`.
"""
+
valid_inputs = {}
invalid_inputs = [
- (MockFile(name='badimage.png', size=10), ['Upload a valid image. The file you uploaded was either not an image or a corrupted image.']),
- (MockFile(name='goodimage.html', size=10), ['File extension is not allowed. Allowed extensions is png.'])
+ (
+ MockFile(name="badimage.png", size=10),
+ [
+ "Upload a valid image. The file you uploaded was either not an image or a corrupted image."
+ ],
+ ),
+ (
+ MockFile(name="goodimage.html", size=10),
+ ["File extension is not allowed. Allowed extensions is png."],
+ ),
]
outputs = {}
field = serializers.ImageField(_DjangoImageField=FailImageValidation)
@@ -1878,8 +1874,9 @@ class TestValidImageField(FieldValues):
"""
Values for an valid `ImageField`.
"""
+
valid_inputs = [
- (MockFile(name='example.png', size=10), MockFile(name='example.png', size=10))
+ (MockFile(name="example.png", size=10), MockFile(name="example.png", size=10))
]
invalid_inputs = {}
outputs = {}
@@ -1888,29 +1885,27 @@ class TestValidImageField(FieldValues):
# Composite serializers...
+
class TestListField(FieldValues):
"""
Values for `ListField` with IntegerField as child.
"""
- valid_inputs = [
- ([1, 2, 3], [1, 2, 3]),
- (['1', '2', '3'], [1, 2, 3]),
- ([], [])
- ]
+
+ valid_inputs = [([1, 2, 3], [1, 2, 3]), (["1", "2", "3"], [1, 2, 3]), ([], [])]
invalid_inputs = [
- ('not a list', ['Expected a list of items but got type "str".']),
- ([1, 2, 'error', 'error'], {2: ['A valid integer is required.'], 3: ['A valid integer is required.']}),
- ({'one': 'two'}, ['Expected a list of items but got type "dict".'])
- ]
- outputs = [
- ([1, 2, 3], [1, 2, 3]),
- (['1', '2', '3'], [1, 2, 3])
+ ("not a list", ['Expected a list of items but got type "str".']),
+ (
+ [1, 2, "error", "error"],
+ {2: ["A valid integer is required."], 3: ["A valid integer is required."]},
+ ),
+ ({"one": "two"}, ['Expected a list of items but got type "dict".']),
]
+ outputs = [([1, 2, 3], [1, 2, 3]), (["1", "2", "3"], [1, 2, 3])]
field = serializers.ListField(child=serializers.IntegerField())
def test_no_source_on_child(self):
with pytest.raises(AssertionError) as exc_info:
- serializers.ListField(child=serializers.IntegerField(source='other'))
+ serializers.ListField(child=serializers.IntegerField(source="other"))
assert str(exc_info.value) == (
"The `source` argument is not meaningful when applied to a `child=` field. "
@@ -1919,40 +1914,45 @@ class TestListField(FieldValues):
def test_collection_types_are_invalid_input(self):
field = serializers.ListField(child=serializers.CharField())
- input_value = ({'one': 'two'})
+ input_value = {"one": "two"}
with pytest.raises(serializers.ValidationError) as exc_info:
field.to_internal_value(input_value)
- assert exc_info.value.detail == ['Expected a list of items but got type "dict".']
+ assert exc_info.value.detail == [
+ 'Expected a list of items but got type "dict".'
+ ]
class TestNestedListField(FieldValues):
"""
Values for nested `ListField` with IntegerField as child.
"""
- valid_inputs = [
- ([[1, 2], [3]], [[1, 2], [3]]),
- ([[]], [[]])
- ]
+
+ valid_inputs = [([[1, 2], [3]], [[1, 2], [3]]), ([[]], [[]])]
invalid_inputs = [
- (['not a list'], {0: ['Expected a list of items but got type "str".']}),
- ([[1, 2, 'error'], ['error']], {0: {2: ['A valid integer is required.']}, 1: {0: ['A valid integer is required.']}}),
- ([{'one': 'two'}], {0: ['Expected a list of items but got type "dict".']})
+ (["not a list"], {0: ['Expected a list of items but got type "str".']}),
+ (
+ [[1, 2, "error"], ["error"]],
+ {
+ 0: {2: ["A valid integer is required."]},
+ 1: {0: ["A valid integer is required."]},
+ },
+ ),
+ ([{"one": "two"}], {0: ['Expected a list of items but got type "dict".']}),
]
- outputs = [
- ([[1, 2], [3]], [[1, 2], [3]]),
- ]
- field = serializers.ListField(child=serializers.ListField(child=serializers.IntegerField()))
+ outputs = [([[1, 2], [3]], [[1, 2], [3]])]
+ field = serializers.ListField(
+ child=serializers.ListField(child=serializers.IntegerField())
+ )
class TestEmptyListField(FieldValues):
"""
Values for `ListField` with allow_empty=False flag.
"""
+
valid_inputs = {}
- invalid_inputs = [
- ([], ['This list may not be empty.'])
- ]
+ invalid_inputs = [([], ["This list may not be empty."])]
outputs = {}
field = serializers.ListField(child=serializers.IntegerField(), allow_empty=False)
@@ -1960,26 +1960,23 @@ class TestEmptyListField(FieldValues):
class TestListFieldLengthLimit(FieldValues):
valid_inputs = ()
invalid_inputs = [
- ((0, 1), ['Ensure this field has at least 3 elements.']),
- ((0, 1, 2, 3, 4, 5), ['Ensure this field has no more than 4 elements.']),
+ ((0, 1), ["Ensure this field has at least 3 elements."]),
+ ((0, 1, 2, 3, 4, 5), ["Ensure this field has no more than 4 elements."]),
]
outputs = ()
- field = serializers.ListField(child=serializers.IntegerField(), min_length=3, max_length=4)
+ field = serializers.ListField(
+ child=serializers.IntegerField(), min_length=3, max_length=4
+ )
class TestUnvalidatedListField(FieldValues):
"""
Values for `ListField` with no `child` argument.
"""
- valid_inputs = [
- ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
- ]
- invalid_inputs = [
- ('not a list', ['Expected a list of items but got type "str".']),
- ]
- outputs = [
- ([1, '2', True, [4, 5, 6]], [1, '2', True, [4, 5, 6]]),
- ]
+
+ valid_inputs = [([1, "2", True, [4, 5, 6]], [1, "2", True, [4, 5, 6]])]
+ invalid_inputs = [("not a list", ['Expected a list of items but got type "str".'])]
+ outputs = [([1, "2", True, [4, 5, 6]], [1, "2", True, [4, 5, 6]])]
field = serializers.ListField()
@@ -1987,21 +1984,24 @@ class TestDictField(FieldValues):
"""
Values for `DictField` with CharField as child.
"""
- valid_inputs = [
- ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
- ]
+
+ valid_inputs = [({"a": 1, "b": "2", 3: 3}, {"a": "1", "b": "2", "3": "3"})]
invalid_inputs = [
- ({'a': 1, 'b': None, 'c': None}, {'b': ['This field may not be null.'], 'c': ['This field may not be null.']}),
- ('not a dict', ['Expected a dictionary of items but got type "str".']),
- ]
- outputs = [
- ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
+ (
+ {"a": 1, "b": None, "c": None},
+ {
+ "b": ["This field may not be null."],
+ "c": ["This field may not be null."],
+ },
+ ),
+ ("not a dict", ['Expected a dictionary of items but got type "str".']),
]
+ outputs = [({"a": 1, "b": "2", 3: 3}, {"a": "1", "b": "2", "3": "3"})]
field = serializers.DictField(child=serializers.CharField())
def test_no_source_on_child(self):
with pytest.raises(AssertionError) as exc_info:
- serializers.DictField(child=serializers.CharField(source='other'))
+ serializers.DictField(child=serializers.CharField(source="other"))
assert str(exc_info.value) == (
"The `source` argument is not meaningful when applied to a `child=` field. "
@@ -2021,31 +2021,45 @@ class TestNestedDictField(FieldValues):
"""
Values for nested `DictField` with CharField as child.
"""
+
valid_inputs = [
- ({0: {'a': 1, 'b': '2'}, 1: {3: 3}}, {'0': {'a': '1', 'b': '2'}, '1': {'3': '3'}}),
+ (
+ {0: {"a": 1, "b": "2"}, 1: {3: 3}},
+ {"0": {"a": "1", "b": "2"}, "1": {"3": "3"}},
+ )
]
invalid_inputs = [
- ({0: {'a': 1, 'b': None}, 1: {'c': None}}, {'0': {'b': ['This field may not be null.']}, '1': {'c': ['This field may not be null.']}}),
- ({0: 'not a dict'}, {'0': ['Expected a dictionary of items but got type "str".']}),
+ (
+ {0: {"a": 1, "b": None}, 1: {"c": None}},
+ {
+ "0": {"b": ["This field may not be null."]},
+ "1": {"c": ["This field may not be null."]},
+ },
+ ),
+ (
+ {0: "not a dict"},
+ {"0": ['Expected a dictionary of items but got type "str".']},
+ ),
]
outputs = [
- ({0: {'a': 1, 'b': '2'}, 1: {3: 3}}, {'0': {'a': '1', 'b': '2'}, '1': {'3': '3'}}),
+ (
+ {0: {"a": 1, "b": "2"}, 1: {3: 3}},
+ {"0": {"a": "1", "b": "2"}, "1": {"3": "3"}},
+ )
]
- field = serializers.DictField(child=serializers.DictField(child=serializers.CharField()))
+ field = serializers.DictField(
+ child=serializers.DictField(child=serializers.CharField())
+ )
class TestDictFieldWithNullChild(FieldValues):
"""
Values for `DictField` with allow_null CharField as child.
"""
- valid_inputs = [
- ({'a': None, 'b': '2', 3: 3}, {'a': None, 'b': '2', '3': '3'}),
- ]
- invalid_inputs = [
- ]
- outputs = [
- ({'a': None, 'b': '2', 3: 3}, {'a': None, 'b': '2', '3': '3'}),
- ]
+
+ valid_inputs = [({"a": None, "b": "2", 3: 3}, {"a": None, "b": "2", "3": "3"})]
+ invalid_inputs = []
+ outputs = [({"a": None, "b": "2", 3: 3}, {"a": None, "b": "2", "3": "3"})]
field = serializers.DictField(child=serializers.CharField(allow_null=True))
@@ -2053,15 +2067,14 @@ class TestUnvalidatedDictField(FieldValues):
"""
Values for `DictField` with no `child` argument.
"""
+
valid_inputs = [
- ({'a': 1, 'b': [4, 5, 6], 1: 123}, {'a': 1, 'b': [4, 5, 6], '1': 123}),
+ ({"a": 1, "b": [4, 5, 6], 1: 123}, {"a": 1, "b": [4, 5, 6], "1": 123})
]
invalid_inputs = [
- ('not a dict', ['Expected a dictionary of items but got type "str".']),
- ]
- outputs = [
- ({'a': 1, 'b': [4, 5, 6]}, {'a': 1, 'b': [4, 5, 6]}),
+ ("not a dict", ['Expected a dictionary of items but got type "str".'])
]
+ outputs = [({"a": 1, "b": [4, 5, 6]}, {"a": 1, "b": [4, 5, 6]})]
field = serializers.DictField()
@@ -2069,16 +2082,15 @@ class TestHStoreField(FieldValues):
"""
Values for `ListField` with CharField as child.
"""
+
valid_inputs = [
- ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
- ({'a': 1, 'b': None}, {'a': '1', 'b': None}),
+ ({"a": 1, "b": "2", 3: 3}, {"a": "1", "b": "2", "3": "3"}),
+ ({"a": 1, "b": None}, {"a": "1", "b": None}),
]
invalid_inputs = [
- ('not a dict', ['Expected a dictionary of items but got type "str".']),
- ]
- outputs = [
- ({'a': 1, 'b': '2', 3: 3}, {'a': '1', 'b': '2', '3': '3'}),
+ ("not a dict", ['Expected a dictionary of items but got type "str".'])
]
+ outputs = [({"a": 1, "b": "2", 3: 3}, {"a": "1", "b": "2", "3": "3"})]
field = serializers.HStoreField()
def test_child_is_charfield(self):
@@ -2092,7 +2104,7 @@ class TestHStoreField(FieldValues):
def test_no_source_on_child(self):
with pytest.raises(AssertionError) as exc_info:
- serializers.HStoreField(child=serializers.CharField(source='other'))
+ serializers.HStoreField(child=serializers.CharField(source="other"))
assert str(exc_info.value) == (
"The `source` argument is not meaningful when applied to a `child=` field. "
@@ -2112,31 +2124,22 @@ class TestJSONField(FieldValues):
"""
Values for `JSONField`.
"""
+
valid_inputs = [
- ({
- 'a': 1,
- 'b': ['some', 'list', True, 1.23],
- '3': None
- }, {
- 'a': 1,
- 'b': ['some', 'list', True, 1.23],
- '3': None
- }),
+ (
+ {"a": 1, "b": ["some", "list", True, 1.23], "3": None},
+ {"a": 1, "b": ["some", "list", True, 1.23], "3": None},
+ )
]
invalid_inputs = [
- ({'a': set()}, ['Value must be valid JSON.']),
- ({'a': float('inf')}, ['Value must be valid JSON.']),
+ ({"a": set()}, ["Value must be valid JSON."]),
+ ({"a": float("inf")}, ["Value must be valid JSON."]),
]
outputs = [
- ({
- 'a': 1,
- 'b': ['some', 'list', True, 1.23],
- '3': 3
- }, {
- 'a': 1,
- 'b': ['some', 'list', True, 1.23],
- '3': 3
- }),
+ (
+ {"a": 1, "b": ["some", "list", True, 1.23], "3": 3},
+ {"a": 1, "b": ["some", "list", True, 1.23], "3": 3},
+ )
]
field = serializers.JSONField()
@@ -2144,72 +2147,69 @@ class TestJSONField(FieldValues):
"""
HTML inputs should be treated as a serialized JSON string.
"""
+
class TestSerializer(serializers.Serializer):
config = serializers.JSONField()
data = QueryDict(mutable=True)
- data.update({'config': '{"a":1}'})
+ data.update({"config": '{"a":1}'})
serializer = TestSerializer(data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {'config': {"a": 1}}
+ assert serializer.validated_data == {"config": {"a": 1}}
class TestBinaryJSONField(FieldValues):
"""
Values for `JSONField` with binary=True.
"""
+
valid_inputs = [
- (b'{"a": 1, "3": null, "b": ["some", "list", true, 1.23]}', {
- 'a': 1,
- 'b': ['some', 'list', True, 1.23],
- '3': None
- }),
- ]
- invalid_inputs = [
- ('{"a": "unterminated string}', ['Value must be valid JSON.']),
- ]
- outputs = [
- (['some', 'list', True, 1.23], b'["some", "list", true, 1.23]'),
+ (
+ b'{"a": 1, "3": null, "b": ["some", "list", true, 1.23]}',
+ {"a": 1, "b": ["some", "list", True, 1.23], "3": None},
+ )
]
+ invalid_inputs = [('{"a": "unterminated string}', ["Value must be valid JSON."])]
+ outputs = [(["some", "list", True, 1.23], b'["some", "list", true, 1.23]')]
field = serializers.JSONField(binary=True)
# Tests for FieldField.
# ---------------------
+
class MockRequest:
def build_absolute_uri(self, value):
- return 'http://example.com' + value
+ return "http://example.com" + value
class TestFileFieldContext:
def test_fully_qualified_when_request_in_context(self):
field = serializers.FileField(max_length=10)
- field._context = {'request': MockRequest()}
- obj = MockFile(name='example.txt', url='/example.txt')
+ field._context = {"request": MockRequest()}
+ obj = MockFile(name="example.txt", url="/example.txt")
value = field.to_representation(obj)
- assert value == 'http://example.com/example.txt'
+ assert value == "http://example.com/example.txt"
# Tests for SerializerMethodField.
# --------------------------------
+
class TestSerializerMethodField:
def test_serializer_method_field(self):
class ExampleSerializer(serializers.Serializer):
example_field = serializers.SerializerMethodField()
def get_example_field(self, obj):
- return 'ran get_example_field(%d)' % obj['example_field']
+ return "ran get_example_field(%d)" % obj["example_field"]
- serializer = ExampleSerializer({'example_field': 123})
- assert serializer.data == {
- 'example_field': 'ran get_example_field(123)'
- }
+ serializer = ExampleSerializer({"example_field": 123})
+ assert serializer.data == {"example_field": "ran get_example_field(123)"}
def test_redundant_method_name(self):
class ExampleSerializer(serializers.Serializer):
- example_field = serializers.SerializerMethodField('get_example_field')
+ example_field = serializers.SerializerMethodField("get_example_field")
with pytest.raises(AssertionError) as exc_info:
ExampleSerializer().fields
@@ -2222,88 +2222,73 @@ class TestSerializerMethodField:
class TestValidationErrorCode:
- @pytest.mark.parametrize('use_list', (False, True))
+ @pytest.mark.parametrize("use_list", (False, True))
def test_validationerror_code_with_msg(self, use_list):
-
class ExampleSerializer(serializers.Serializer):
password = serializers.CharField()
def validate_password(self, obj):
- err = DjangoValidationError('exc_msg', code='exc_code')
+ err = DjangoValidationError("exc_msg", code="exc_code")
if use_list:
err = DjangoValidationError([err])
raise err
- serializer = ExampleSerializer(data={'password': 123})
+ serializer = ExampleSerializer(data={"password": 123})
serializer.is_valid()
- assert serializer.errors == {'password': ['exc_msg']}
- assert serializer.errors['password'][0].code == 'exc_code'
+ assert serializer.errors == {"password": ["exc_msg"]}
+ assert serializer.errors["password"][0].code == "exc_code"
- @pytest.mark.parametrize('code', (None, 'exc_code',))
- @pytest.mark.parametrize('use_list', (False, True))
+ @pytest.mark.parametrize("code", (None, "exc_code"))
+ @pytest.mark.parametrize("use_list", (False, True))
def test_validationerror_code_with_dict(self, use_list, code):
-
class ExampleSerializer(serializers.Serializer):
-
def validate(self, obj):
if code is None:
- err = DjangoValidationError({
- 'email': 'email error',
- })
+ err = DjangoValidationError({"email": "email error"})
else:
- err = DjangoValidationError({
- 'email': DjangoValidationError(
- 'email error',
- code=code),
- })
+ err = DjangoValidationError(
+ {"email": DjangoValidationError("email error", code=code)}
+ )
if use_list:
err = DjangoValidationError([err])
raise err
serializer = ExampleSerializer(data={})
serializer.is_valid()
- expected_code = code if code else 'invalid'
+ expected_code = code if code else "invalid"
if use_list:
assert serializer.errors == {
- 'non_field_errors': [
- exceptions.ErrorDetail(
- string='email error',
- code=expected_code
- )
+ "non_field_errors": [
+ exceptions.ErrorDetail(string="email error", code=expected_code)
]
}
else:
- assert serializer.errors == {
- 'email': ['email error'],
- }
- assert serializer.errors['email'][0].code == expected_code
+ assert serializer.errors == {"email": ["email error"]}
+ assert serializer.errors["email"][0].code == expected_code
- @pytest.mark.parametrize('code', (None, 'exc_code',))
+ @pytest.mark.parametrize("code", (None, "exc_code"))
def test_validationerror_code_with_dict_list_same_code(self, code):
-
class ExampleSerializer(serializers.Serializer):
-
def validate(self, obj):
if code is None:
- raise DjangoValidationError({'email': ['email error 1',
- 'email error 2']})
- raise DjangoValidationError({'email': [
- DjangoValidationError('email error 1', code=code),
- DjangoValidationError('email error 2', code=code),
- ]})
+ raise DjangoValidationError(
+ {"email": ["email error 1", "email error 2"]}
+ )
+ raise DjangoValidationError(
+ {
+ "email": [
+ DjangoValidationError("email error 1", code=code),
+ DjangoValidationError("email error 2", code=code),
+ ]
+ }
+ )
serializer = ExampleSerializer(data={})
serializer.is_valid()
- expected_code = code if code else 'invalid'
+ expected_code = code if code else "invalid"
assert serializer.errors == {
- 'email': [
- exceptions.ErrorDetail(
- string='email error 1',
- code=expected_code
- ),
- exceptions.ErrorDetail(
- string='email error 2',
- code=expected_code
- ),
+ "email": [
+ exceptions.ErrorDetail(string="email error 1", code=expected_code),
+ exceptions.ErrorDetail(string="email error 2", code=expected_code),
]
}
diff --git a/tests/test_filters.py b/tests/test_filters.py
index a53fa192a..dac378321 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -14,6 +14,7 @@ from rest_framework import filters, generics, serializers
from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory
+
factory = APIRequestFactory()
@@ -30,7 +31,7 @@ class BaseFilterTests(TestCase):
with pytest.raises(NotImplementedError):
self.filter_backend.filter_queryset(None, None, None)
- @pytest.mark.skipif(not coreschema, reason='coreschema is not installed')
+ @pytest.mark.skipif(not coreschema, reason="coreschema is not installed")
def test_get_schema_fields_checks_for_coreapi(self):
filters.coreapi = None
with pytest.raises(AssertionError):
@@ -47,7 +48,7 @@ class SearchFilterModel(models.Model):
class SearchFilterSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModel
- fields = '__all__'
+ fields = "__all__"
class SearchFilterTests(TestCase):
@@ -59,12 +60,8 @@ class SearchFilterTests(TestCase):
# zzz cde
# ...
for idx in range(10):
- title = 'z' * (idx + 1)
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
+ title = "z" * (idx + 1)
+ text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
SearchFilterModel(title=title, text=text).save()
def test_search(self):
@@ -72,14 +69,14 @@ class SearchFilterTests(TestCase):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('title', 'text')
+ search_fields = ("title", "text")
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'b'})
+ request = factory.get("/", {"search": "b"})
response = view(request)
assert response.data == [
- {'id': 1, 'title': 'z', 'text': 'abc'},
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ {"id": 1, "title": "z", "text": "abc"},
+ {"id": 2, "title": "zz", "text": "bcd"},
]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
@@ -89,10 +86,11 @@ class SearchFilterTests(TestCase):
filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request)
- expected = SearchFilterSerializer(SearchFilterModel.objects.all(),
- many=True).data
+ expected = SearchFilterSerializer(
+ SearchFilterModel.objects.all(), many=True
+ ).data
assert response.data == expected
def test_exact_search(self):
@@ -100,59 +98,53 @@ class SearchFilterTests(TestCase):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('=title', 'text')
+ search_fields = ("=title", "text")
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'zzz'})
+ request = factory.get("/", {"search": "zzz"})
response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'zzz', 'text': 'cde'}
- ]
+ assert response.data == [{"id": 3, "title": "zzz", "text": "cde"}]
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('title', '^text')
+ search_fields = ("title", "^text")
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'b'})
+ request = factory.get("/", {"search": "b"})
response = view(request)
- assert response.data == [
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
+ assert response.data == [{"id": 2, "title": "zz", "text": "bcd"}]
def test_regexp_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('$title', '$text')
+ search_fields = ("$title", "$text")
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'z{2} ^b'})
+ request = factory.get("/", {"search": "z{2} ^b"})
response = view(request)
- assert response.data == [
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
- ]
+ assert response.data == [{"id": 2, "title": "zz", "text": "bcd"}]
def test_search_with_nonstandard_search_param(self):
- with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
+ with override_settings(REST_FRAMEWORK={"SEARCH_PARAM": "query"}):
reload_module(filters)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('title', 'text')
+ search_fields = ("title", "text")
view = SearchListView.as_view()
- request = factory.get('/', {'query': 'b'})
+ request = factory.get("/", {"query": "b"})
response = view(request)
assert response.data == [
- {'id': 1, 'title': 'z', 'text': 'abc'},
- {'id': 2, 'title': 'zz', 'text': 'bcd'}
+ {"id": 1, "title": "z", "text": "abc"},
+ {"id": 2, "title": "zz", "text": "bcd"},
]
reload_module(filters)
@@ -161,26 +153,24 @@ class SearchFilterTests(TestCase):
class CustomSearchFilter(filters.SearchFilter):
# Filter that dynamically changes search fields
def get_search_fields(self, view, request):
- if request.query_params.get('title_only'):
- return ('$title',)
+ if request.query_params.get("title_only"):
+ return ("$title",)
return super(CustomSearchFilter, self).get_search_fields(view, request)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,)
- search_fields = ('$title', '$text')
+ search_fields = ("$title", "$text")
view = SearchListView.as_view()
- request = factory.get('/', {'search': r'^\w{3}$'})
+ request = factory.get("/", {"search": r"^\w{3}$"})
response = view(request)
assert len(response.data) == 10
- request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
+ request = factory.get("/", {"search": r"^\w{3}$", "title_only": "true"})
response = view(request)
- assert response.data == [
- {'id': 3, 'title': 'zzz', 'text': 'cde'}
- ]
+ assert response.data == [{"id": 3, "title": "zzz", "text": "cde"}]
class AttributeModel(models.Model):
@@ -195,33 +185,31 @@ class SearchFilterModelFk(models.Model):
class SearchFilterFkSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModelFk
- fields = '__all__'
+ fields = "__all__"
class SearchFilterFkTests(TestCase):
-
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
- prefixes = [''] + list(filter_.lookup_prefixes)
+ prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
- SearchFilterModelFk._meta,
- ["%stitle" % prefix]
+ SearchFilterModelFk._meta, ["%stitle" % prefix]
)
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
- ["%stitle" % prefix, "%sattribute__label" % prefix]
+ ["%stitle" % prefix, "%sattribute__label" % prefix],
)
def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the
# list of search fields.
filter_ = filters.SearchFilter()
- prefixes = [''] + list(filter_.lookup_prefixes)
+ prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
- ["%sattribute__label" % prefix, "%stitle" % prefix]
+ ["%sattribute__label" % prefix, "%stitle" % prefix],
)
@@ -234,7 +222,7 @@ class SearchFilterModelM2M(models.Model):
class SearchFilterM2MSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModelM2M
- fields = '__all__'
+ fields = "__all__"
class SearchFilterM2MTests(TestCase):
@@ -246,43 +234,38 @@ class SearchFilterM2MTests(TestCase):
# zzz cde [1, 2, 3]
# ...
for idx in range(3):
- label = 'w' * (idx + 1)
+ label = "w" * (idx + 1)
AttributeModel.objects.create(label=label)
for idx in range(10):
- title = 'z' * (idx + 1)
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
+ title = "z" * (idx + 1)
+ text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
SearchFilterModelM2M(title=title, text=text).save()
- SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
+ SearchFilterModelM2M.objects.get(title="zz").attributes.add(1, 2, 3)
def test_m2m_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModelM2M.objects.all()
serializer_class = SearchFilterM2MSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('=title', 'text', 'attributes__label')
+ search_fields = ("=title", "text", "attributes__label")
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'zz'})
+ request = factory.get("/", {"search": "zz"})
response = view(request)
assert len(response.data) == 1
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
- prefixes = [''] + list(filter_.lookup_prefixes)
+ prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
- SearchFilterModelM2M._meta,
- ["%stitle" % prefix]
+ SearchFilterModelM2M._meta, ["%stitle" % prefix]
)
assert filter_.must_call_distinct(
SearchFilterModelM2M._meta,
- ["%stitle" % prefix, "%sattributes__label" % prefix]
+ ["%stitle" % prefix, "%sattributes__label" % prefix],
)
@@ -299,33 +282,46 @@ class Entry(models.Model):
class BlogSerializer(serializers.ModelSerializer):
class Meta:
model = Blog
- fields = '__all__'
+ fields = "__all__"
class SearchFilterToManyTests(TestCase):
-
@classmethod
def setUpTestData(cls):
- b1 = Blog.objects.create(name='Blog 1')
- b2 = Blog.objects.create(name='Blog 2')
+ b1 = Blog.objects.create(name="Blog 1")
+ b2 = Blog.objects.create(name="Blog 2")
# Multiple entries on Lennon published in 1979 - distinct should deduplicate
- Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
- Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
+ Entry.objects.create(
+ blog=b1,
+ headline="Something about Lennon",
+ pub_date=datetime.date(1979, 1, 1),
+ )
+ Entry.objects.create(
+ blog=b1,
+ headline="Another thing about Lennon",
+ pub_date=datetime.date(1979, 6, 1),
+ )
# Entry on Lennon *and* a separate entry in 1979 - should not match
- Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
- Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
+ Entry.objects.create(
+ blog=b2, headline="Something unrelated", pub_date=datetime.date(1979, 1, 1)
+ )
+ Entry.objects.create(
+ blog=b2,
+ headline="Retrospective on Lennon",
+ pub_date=datetime.date(1990, 6, 1),
+ )
def test_multiple_filter_conditions(self):
class SearchListView(generics.ListAPIView):
queryset = Blog.objects.all()
serializer_class = BlogSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
+ search_fields = ("=name", "entry__headline", "=entry__pub_date__year")
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'Lennon,1979'})
+ request = factory.get("/", {"search": "Lennon,1979"})
response = view(request)
assert len(response.data) == 1
@@ -335,60 +331,58 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModel
- fields = ('title', 'text', 'title_text')
+ fields = ("title", "text", "title_text")
class SearchFilterAnnotatedFieldTests(TestCase):
@classmethod
def setUpTestData(cls):
- SearchFilterModel.objects.create(title='abc', text='def')
- SearchFilterModel.objects.create(title='ghi', text='jkl')
+ SearchFilterModel.objects.create(title="abc", text="def")
+ SearchFilterModel.objects.create(title="ghi", text="jkl")
def test_search_in_annotated_field(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate(
- title_text=Upper(
- Concat(models.F('title'), models.F('text'))
- )
+ title_text=Upper(Concat(models.F("title"), models.F("text")))
).all()
serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,)
- search_fields = ('title_text',)
+ search_fields = ("title_text",)
view = SearchListView.as_view()
- request = factory.get('/', {'search': 'ABCDEF'})
+ request = factory.get("/", {"search": "ABCDEF"})
response = view(request)
assert len(response.data) == 1
- assert response.data[0]['title_text'] == 'ABCDEF'
+ assert response.data[0]["title_text"] == "ABCDEF"
class OrderingFilterModel(models.Model):
- title = models.CharField(max_length=20, verbose_name='verbose title')
+ title = models.CharField(max_length=20, verbose_name="verbose title")
text = models.CharField(max_length=100)
class OrderingFilterRelatedModel(models.Model):
- related_object = models.ForeignKey(OrderingFilterModel, related_name="relateds", on_delete=models.CASCADE)
- index = models.SmallIntegerField(help_text="A non-related field to test with", default=0)
+ related_object = models.ForeignKey(
+ OrderingFilterModel, related_name="relateds", on_delete=models.CASCADE
+ )
+ index = models.SmallIntegerField(
+ help_text="A non-related field to test with", default=0
+ )
class OrderingFilterSerializer(serializers.ModelSerializer):
class Meta:
model = OrderingFilterModel
- fields = '__all__'
+ fields = "__all__"
class OrderingDottedRelatedSerializer(serializers.ModelSerializer):
- related_text = serializers.CharField(source='related_object.text')
- related_title = serializers.CharField(source='related_object.title')
+ related_text = serializers.CharField(source="related_object.text")
+ related_title = serializers.CharField(source="related_object.title")
class Meta:
model = OrderingFilterRelatedModel
- fields = (
- 'related_text',
- 'related_title',
- 'index',
- )
+ fields = ("related_text", "related_title", "index")
class DjangoFilterOrderingModel(models.Model):
@@ -396,13 +390,13 @@ class DjangoFilterOrderingModel(models.Model):
text = models.CharField(max_length=10)
class Meta:
- ordering = ['-date']
+ ordering = ["-date"]
class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
class Meta:
model = DjangoFilterOrderingModel
- fields = '__all__'
+ fields = "__all__"
class OrderingFilterTests(TestCase):
@@ -413,16 +407,8 @@ class OrderingFilterTests(TestCase):
# yxw bcd
# xwv cde
for idx in range(3):
- title = (
- chr(ord('z') - idx) +
- chr(ord('y') - idx) +
- chr(ord('x') - idx)
- )
- text = (
- chr(idx + ord('a')) +
- chr(idx + ord('b')) +
- chr(idx + ord('c'))
- )
+ title = chr(ord("z") - idx) + chr(ord("y") - idx) + chr(ord("x") - idx)
+ text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
OrderingFilterModel(title=title, text=text).save()
def test_ordering(self):
@@ -430,16 +416,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ ordering = ("title",)
+ ordering_fields = ("text",)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'text'})
+ request = factory.get("/", {"ordering": "text"})
response = view(request)
assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {"id": 1, "title": "zyx", "text": "abc"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 3, "title": "xwv", "text": "cde"},
]
def test_reverse_ordering(self):
@@ -447,16 +433,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ ordering = ("title",)
+ ordering_fields = ("text",)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '-text'})
+ request = factory.get("/", {"ordering": "-text"})
response = view(request)
assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {"id": 3, "title": "xwv", "text": "cde"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 1, "title": "zyx", "text": "abc"},
]
def test_incorrecturl_extrahyphens_ordering(self):
@@ -464,16 +450,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ ordering = ("title",)
+ ordering_fields = ("text",)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '--text'})
+ request = factory.get("/", {"ordering": "--text"})
response = view(request)
assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {"id": 3, "title": "xwv", "text": "cde"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 1, "title": "zyx", "text": "abc"},
]
def test_incorrectfield_ordering(self):
@@ -481,16 +467,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ ordering = ("title",)
+ ordering_fields = ("text",)
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'foobar'})
+ request = factory.get("/", {"ordering": "foobar"})
response = view(request)
assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {"id": 3, "title": "xwv", "text": "cde"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 1, "title": "zyx", "text": "abc"},
]
def test_default_ordering(self):
@@ -498,16 +484,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ ordering = ("title",)
+ ordering_fields = ("text",)
view = OrderingListView.as_view()
- request = factory.get('')
+ request = factory.get("")
response = view(request)
assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {"id": 3, "title": "xwv", "text": "cde"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 1, "title": "zyx", "text": "abc"},
]
def test_default_ordering_using_string(self):
@@ -515,53 +501,48 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = ('text',)
+ ordering = "title"
+ ordering_fields = ("text",)
view = OrderingListView.as_view()
- request = factory.get('')
+ request = factory.get("")
response = view(request)
assert response.data == [
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
+ {"id": 3, "title": "xwv", "text": "cde"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 1, "title": "zyx", "text": "abc"},
]
def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by
num_objs = [2, 5, 3]
- for obj, num_relateds in zip(OrderingFilterModel.objects.all(),
- num_objs):
+ for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
for _ in range(num_relateds):
- new_related = OrderingFilterRelatedModel(
- related_object=obj
- )
+ new_related = OrderingFilterRelatedModel(related_object=obj)
new_related.save()
class OrderingListView(generics.ListAPIView):
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = 'title'
- ordering_fields = '__all__'
+ ordering = "title"
+ ordering_fields = "__all__"
queryset = OrderingFilterModel.objects.all().annotate(
- models.Count("relateds"))
+ models.Count("relateds")
+ )
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'relateds__count'})
+ request = factory.get("/", {"ordering": "relateds__count"})
response = view(request)
assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
+ {"id": 1, "title": "zyx", "text": "abc"},
+ {"id": 3, "title": "xwv", "text": "cde"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
]
def test_ordering_by_dotted_source(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()):
- OrderingFilterRelatedModel.objects.create(
- related_object=obj,
- index=index
- )
+ OrderingFilterRelatedModel.objects.create(related_object=obj, index=index)
class OrderingListView(generics.ListAPIView):
serializer_class = OrderingDottedRelatedSerializer
@@ -569,62 +550,62 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterRelatedModel.objects.all()
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'related_object__text'})
+ request = factory.get("/", {"ordering": "related_object__text"})
response = view(request)
assert response.data == [
- {'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
- {'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
- {'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
+ {"related_title": "zyx", "related_text": "abc", "index": 0},
+ {"related_title": "yxw", "related_text": "bcd", "index": 1},
+ {"related_title": "xwv", "related_text": "cde", "index": 2},
]
- request = factory.get('/', {'ordering': '-index'})
+ request = factory.get("/", {"ordering": "-index"})
response = view(request)
assert response.data == [
- {'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
- {'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
- {'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
+ {"related_title": "xwv", "related_text": "cde", "index": 2},
+ {"related_title": "yxw", "related_text": "bcd", "index": 1},
+ {"related_title": "zyx", "related_text": "abc", "index": 0},
]
def test_ordering_with_nonstandard_ordering_param(self):
- with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
+ with override_settings(REST_FRAMEWORK={"ORDERING_PARAM": "order"}):
reload_module(filters)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
- ordering_fields = ('text',)
+ ordering = ("title",)
+ ordering_fields = ("text",)
view = OrderingListView.as_view()
- request = factory.get('/', {'order': 'text'})
+ request = factory.get("/", {"order": "text"})
response = view(request)
assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {"id": 1, "title": "zyx", "text": "abc"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 3, "title": "xwv", "text": "cde"},
]
reload_module(filters)
def test_get_template_context(self):
class OrderingListView(generics.ListAPIView):
- ordering_fields = '__all__'
+ ordering_fields = "__all__"
serializer_class = OrderingFilterSerializer
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
- request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html')
+ request = factory.get("/", {"ordering": "title"}, HTTP_ACCEPT="text/html")
view = OrderingListView.as_view()
response = view(request)
- self.assertContains(response, 'verbose title')
+ self.assertContains(response, "verbose title")
def test_ordering_with_overridden_get_serializer_class(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
+ ordering = ("title",)
# note: no ordering_fields and serializer_class specified
@@ -632,24 +613,24 @@ class OrderingFilterTests(TestCase):
return OrderingFilterSerializer
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'text'})
+ request = factory.get("/", {"ordering": "text"})
response = view(request)
assert response.data == [
- {'id': 1, 'title': 'zyx', 'text': 'abc'},
- {'id': 2, 'title': 'yxw', 'text': 'bcd'},
- {'id': 3, 'title': 'xwv', 'text': 'cde'},
+ {"id": 1, "title": "zyx", "text": "abc"},
+ {"id": 2, "title": "yxw", "text": "bcd"},
+ {"id": 3, "title": "xwv", "text": "cde"},
]
def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
- ordering = ('title',)
+ ordering = ("title",)
# note: no ordering_fields and serializer_class
# or get_serializer_class specified
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'text'})
+ request = factory.get("/", {"ordering": "text"})
with self.assertRaises(ImproperlyConfigured):
view(request)
@@ -666,7 +647,7 @@ class SensitiveDataSerializer1(serializers.ModelSerializer):
class Meta:
model = SensitiveOrderingFilterModel
- fields = ('id', 'username')
+ fields = ("id", "username")
class SensitiveDataSerializer2(serializers.ModelSerializer):
@@ -675,74 +656,80 @@ class SensitiveDataSerializer2(serializers.ModelSerializer):
class Meta:
model = SensitiveOrderingFilterModel
- fields = ('id', 'username', 'password')
+ fields = ("id", "username", "password")
class SensitiveDataSerializer3(serializers.ModelSerializer):
- user = serializers.CharField(source='username')
+ user = serializers.CharField(source="username")
class Meta:
model = SensitiveOrderingFilterModel
- fields = ('id', 'user')
+ fields = ("id", "user")
class SensitiveOrderingFilterTests(TestCase):
def setUp(self):
for idx in range(3):
- username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
- password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
+ username = {0: "userA", 1: "userB", 2: "userC"}[idx]
+ password = {0: "passA", 1: "passC", 2: "passB"}[idx]
SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self):
for serializer_cls in [
SensitiveDataSerializer1,
SensitiveDataSerializer2,
- SensitiveDataSerializer3
+ SensitiveDataSerializer3,
]:
+
class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by(
+ "username"
+ )
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': '-username'})
+ request = factory.get("/", {"ordering": "-username"})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
+ username_field = "user"
else:
- username_field = 'username'
+ username_field = "username"
# Note: Inverse username ordering correctly applied.
assert response.data == [
- {'id': 3, username_field: 'userC'},
- {'id': 2, username_field: 'userB'},
- {'id': 1, username_field: 'userA'},
+ {"id": 3, username_field: "userC"},
+ {"id": 2, username_field: "userB"},
+ {"id": 1, username_field: "userA"},
]
def test_cannot_order_by_non_serializer_fields(self):
for serializer_cls in [
SensitiveDataSerializer1,
SensitiveDataSerializer2,
- SensitiveDataSerializer3
+ SensitiveDataSerializer3,
]:
+
class OrderingListView(generics.ListAPIView):
- queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
+ queryset = SensitiveOrderingFilterModel.objects.all().order_by(
+ "username"
+ )
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view()
- request = factory.get('/', {'ordering': 'password'})
+ request = factory.get("/", {"ordering": "password"})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
- username_field = 'user'
+ username_field = "user"
else:
- username_field = 'username'
+ username_field = "username"
# Note: The passwords are not in order. Default ordering is used.
assert response.data == [
- {'id': 1, username_field: 'userA'}, # PassB
- {'id': 2, username_field: 'userB'}, # PassC
- {'id': 3, username_field: 'userC'}, # PassA
+ {"id": 1, username_field: "userA"}, # PassB
+ {"id": 2, username_field: "userB"}, # PassC
+ {"id": 3, username_field: "userC"}, # PassA
]
diff --git a/tests/test_generateschema.py b/tests/test_generateschema.py
index 915c6ea05..fab1ee966 100644
--- a/tests/test_generateschema.py
+++ b/tests/test_generateschema.py
@@ -17,20 +17,18 @@ class FooView(APIView):
pass
-urlpatterns = [
- url(r'^$', FooView.as_view())
-]
+urlpatterns = [url(r"^$", FooView.as_view())]
-@override_settings(ROOT_URLCONF='tests.test_generateschema')
-@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+@override_settings(ROOT_URLCONF="tests.test_generateschema")
+@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
class GenerateSchemaTests(TestCase):
"""Tests for management command generateschema."""
def setUp(self):
self.out = six.StringIO()
- @pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.')
+ @pytest.mark.skipif(six.PY2, reason="PyYAML unicode output is malformed on PY2.")
def test_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info:
description: Sample description
@@ -44,45 +42,29 @@ class GenerateSchemaTests(TestCase):
servers:
- url: http://api.sample.com/
"""
- call_command('generateschema',
- '--title=SampleAPI',
- '--url=http://api.sample.com',
- '--description=Sample description',
- stdout=self.out)
+ call_command(
+ "generateschema",
+ "--title=SampleAPI",
+ "--url=http://api.sample.com",
+ "--description=Sample description",
+ stdout=self.out,
+ )
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self):
expected_out = {
"openapi": "3.0.0",
- "info": {
- "version": "",
- "title": "",
- "description": ""
- },
- "servers": [
- {
- "url": ""
- }
- ],
- "paths": {
- "/": {
- "get": {
- "operationId": "list"
- }
- }
- }
+ "info": {"version": "", "title": "", "description": ""},
+ "servers": [{"url": ""}],
+ "paths": {"/": {"get": {"operationId": "list"}}},
}
- call_command('generateschema',
- '--format=openapi-json',
- stdout=self.out)
+ call_command("generateschema", "--format=openapi-json", stdout=self.out)
out_json = json.loads(self.out.getvalue())
self.assertDictEqual(out_json, expected_out)
def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
- call_command('generateschema',
- '--format=corejson',
- stdout=self.out)
+ call_command("generateschema", "--format=corejson", stdout=self.out)
self.assertIn(expected_out, self.out.getvalue())
diff --git a/tests/test_generics.py b/tests/test_generics.py
index c0ff1c5c4..29260cc22 100644
--- a/tests/test_generics.py
+++ b/tests/test_generics.py
@@ -11,10 +11,14 @@ from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from tests.models import (
- BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel,
- UUIDForeignKeyTarget
+ BasicModel,
+ ForeignKeySource,
+ ForeignKeyTarget,
+ RESTFrameworkModel,
+ UUIDForeignKeyTarget,
)
+
factory = APIRequestFactory()
@@ -35,13 +39,13 @@ class Comment(RESTFrameworkModel):
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
- fields = '__all__'
+ fields = "__all__"
class ForeignKeySerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
- fields = '__all__'
+ fields = "__all__"
class SlugSerializer(serializers.ModelSerializer):
@@ -49,7 +53,7 @@ class SlugSerializer(serializers.ModelSerializer):
class Meta:
model = SlugBasedModel
- fields = ('text', 'slug')
+ fields = ("text", "slug")
# Views
@@ -59,7 +63,7 @@ class RootView(generics.ListCreateAPIView):
class InstanceView(generics.RetrieveUpdateDestroyAPIView):
- queryset = BasicModel.objects.exclude(text='filtered out')
+ queryset = BasicModel.objects.exclude(text="filtered out")
serializer_class = BasicSerializer
@@ -72,9 +76,10 @@ class SlugBasedInstanceView(InstanceView):
"""
A model with a slug-field.
"""
+
queryset = SlugBasedModel.objects.all()
serializer_class = SlugSerializer
- lookup_field = 'slug'
+ lookup_field = "slug"
# Tests
@@ -83,21 +88,18 @@ class TestRootView(TestCase):
"""
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz']
+ items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
+ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
self.view = RootView.as_view()
def test_get_root_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
- request = factory.get('/')
+ request = factory.get("/")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
@@ -107,7 +109,7 @@ class TestRootView(TestCase):
"""
HEAD requests to ListCreateAPIView should return 200.
"""
- request = factory.head('/')
+ request = factory.head("/")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
@@ -116,21 +118,21 @@ class TestRootView(TestCase):
"""
POST requests to ListCreateAPIView should create a new object.
"""
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
+ data = {"text": "foobar"}
+ request = factory.post("/", data, format="json")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
- assert response.data == {'id': 4, 'text': 'foobar'}
+ assert response.data == {"id": 4, "text": "foobar"}
created = self.objects.get(id=4)
- assert created.text == 'foobar'
+ assert created.text == "foobar"
def test_put_root_view(self):
"""
PUT requests to ListCreateAPIView should not be allowed
"""
- data = {'text': 'foobar'}
- request = factory.put('/', data, format='json')
+ data = {"text": "foobar"}
+ request = factory.put("/", data, format="json")
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@@ -140,7 +142,7 @@ class TestRootView(TestCase):
"""
DELETE requests to ListCreateAPIView should not be allowed
"""
- request = factory.delete('/')
+ request = factory.delete("/")
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@@ -150,24 +152,24 @@ class TestRootView(TestCase):
"""
POST requests to create a new object should not be able to set the id.
"""
- data = {'id': 999, 'text': 'foobar'}
- request = factory.post('/', data, format='json')
+ data = {"id": 999, "text": "foobar"}
+ request = factory.post("/", data, format="json")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
- assert response.data == {'id': 4, 'text': 'foobar'}
+ assert response.data == {"id": 4, "text": "foobar"}
created = self.objects.get(id=4)
- assert created.text == 'foobar'
+ assert created.text == "foobar"
def test_post_error_root_view(self):
"""
POST requests to ListCreateAPIView in HTML should include a form error.
"""
- data = {'text': 'foobar' * 100}
- request = factory.post('/', data, HTTP_ACCEPT='text/html')
+ data = {"text": "foobar" * 100}
+ request = factory.post("/", data, HTTP_ACCEPT="text/html")
response = self.view(request).render()
expected_error = 'Ensure this field has no more than 100 characters.'
- assert expected_error in response.rendered_content.decode('utf-8')
+ assert expected_error in response.rendered_content.decode("utf-8")
EXPECTED_QUERIES_FOR_PUT = 2
@@ -178,14 +180,11 @@ class TestInstanceView(TestCase):
"""
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz', 'filtered out']
+ items = ["foo", "bar", "baz", "filtered out"]
for item in items:
BasicModel(text=item).save()
- self.objects = BasicModel.objects.exclude(text='filtered out')
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
+ self.objects = BasicModel.objects.exclude(text="filtered out")
+ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
@@ -193,7 +192,7 @@ class TestInstanceView(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
- request = factory.get('/1')
+ request = factory.get("/1")
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
@@ -203,8 +202,8 @@ class TestInstanceView(TestCase):
"""
POST requests to RetrieveUpdateDestroyAPIView should not be allowed
"""
- data = {'text': 'foobar'}
- request = factory.post('/', data, format='json')
+ data = {"text": "foobar"}
+ request = factory.post("/", data, format="json")
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@@ -214,38 +213,38 @@ class TestInstanceView(TestCase):
"""
PUT requests to RetrieveUpdateDestroyAPIView should update an object.
"""
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
+ data = {"text": "foobar"}
+ request = factory.put("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
- response = self.view(request, pk='1').render()
+ response = self.view(request, pk="1").render()
assert response.status_code == status.HTTP_200_OK
- assert dict(response.data) == {'id': 1, 'text': 'foobar'}
+ assert dict(response.data) == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1)
- assert updated.text == 'foobar'
+ assert updated.text == "foobar"
def test_patch_instance_view(self):
"""
PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
"""
- data = {'text': 'foobar'}
- request = factory.patch('/1', data, format='json')
+ data = {"text": "foobar"}
+ request = factory.patch("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'id': 1, 'text': 'foobar'}
+ assert response.data == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1)
- assert updated.text == 'foobar'
+ assert updated.text == "foobar"
def test_delete_instance_view(self):
"""
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
- request = factory.delete('/1')
+ request = factory.delete("/1")
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_204_NO_CONTENT
- assert response.content == six.b('')
+ assert response.content == six.b("")
ids = [obj.id for obj in self.objects.all()]
assert ids == [2, 3]
@@ -254,23 +253,23 @@ class TestInstanceView(TestCase):
GET requests with an incorrect pk type, should raise 404, not 500.
Regression test for #890.
"""
- request = factory.get('/a')
+ request = factory.get("/a")
with self.assertNumQueries(0):
- response = self.view(request, pk='a').render()
+ response = self.view(request, pk="a").render()
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self):
"""
PUT requests to create a new object should not be able to set the id.
"""
- data = {'id': 999, 'text': 'foobar'}
- request = factory.put('/1', data, format='json')
+ data = {"id": 999, "text": "foobar"}
+ request = factory.put("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'id': 1, 'text': 'foobar'}
+ assert response.data == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1)
- assert updated.text == 'foobar'
+ assert updated.text == "foobar"
def test_put_to_deleted_instance(self):
"""
@@ -278,8 +277,8 @@ class TestInstanceView(TestCase):
an object does not currently exist.
"""
self.objects.get(id=1).delete()
- data = {'text': 'foobar'}
- request = factory.put('/1', data, format='json')
+ data = {"text": "foobar"}
+ request = factory.put("/1", data, format="json")
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -289,9 +288,9 @@ class TestInstanceView(TestCase):
PUT requests to an URL of instance which is filtered out should not be
able to create new objects.
"""
- data = {'text': 'foo'}
- filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
- request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
+ data = {"text": "foo"}
+ filtered_out_pk = BasicModel.objects.filter(text="filtered out")[0].pk
+ request = factory.put("/{0}".format(filtered_out_pk), data, format="json")
response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -299,8 +298,8 @@ class TestInstanceView(TestCase):
"""
PATCH requests should not be able to create objects.
"""
- data = {'text': 'foobar'}
- request = factory.patch('/999', data, format='json')
+ data = {"text": "foobar"}
+ request = factory.patch("/999", data, format="json")
with self.assertNumQueries(1):
response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -310,11 +309,11 @@ class TestInstanceView(TestCase):
"""
Incorrect PUT requests in HTML should include a form error.
"""
- data = {'text': 'foobar' * 100}
- request = factory.put('/', data, HTTP_ACCEPT='text/html')
+ data = {"text": "foobar" * 100}
+ request = factory.put("/", data, HTTP_ACCEPT="text/html")
response = self.view(request, pk=1).render()
expected_error = 'Ensure this field has no more than 100 characters.'
- assert expected_error in response.rendered_content.decode('utf-8')
+ assert expected_error in response.rendered_content.decode("utf-8")
class TestFKInstanceView(TestCase):
@@ -322,17 +321,14 @@ class TestFKInstanceView(TestCase):
"""
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz']
+ items = ["foo", "bar", "baz"]
for item in items:
t = ForeignKeyTarget(name=item)
t.save()
- ForeignKeySource(name='source_' + item, target=t).save()
+ ForeignKeySource(name="source_" + item, target=t).save()
self.objects = ForeignKeySource.objects
- self.data = [
- {'id': obj.id, 'name': obj.name}
- for obj in self.objects.all()
- ]
+ self.data = [{"id": obj.id, "name": obj.name} for obj in self.objects.all()]
self.view = FKInstanceView.as_view()
@@ -346,23 +342,21 @@ class TestOverriddenGetObject(TestCase):
"""
Create 3 BasicModel instances.
"""
- items = ['foo', 'bar', 'baz']
+ items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
+ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
"""
Example detail view for override of get_object().
"""
+
serializer_class = BasicSerializer
def get_object(self):
- pk = int(self.kwargs['pk'])
+ pk = int(self.kwargs["pk"])
return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view()
@@ -371,7 +365,7 @@ class TestOverriddenGetObject(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
- request = factory.get('/1')
+ request = factory.get("/1")
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
@@ -380,10 +374,11 @@ class TestOverriddenGetObject(TestCase):
# Regression test for #285
+
class CommentSerializer(serializers.ModelSerializer):
class Meta:
model = Comment
- exclude = ('created',)
+ exclude = ("created",)
class CommentView(generics.ListCreateAPIView):
@@ -402,12 +397,12 @@ class TestCreateModelWithAutoNowAddField(TestCase):
https://github.com/encode/django-rest-framework/issues/285
"""
- data = {'email': 'foobar@example.com', 'content': 'foobar'}
- request = factory.post('/', data, format='json')
+ data = {"email": "foobar@example.com", "content": "foobar"}
+ request = factory.post("/", data, format="json")
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1)
- assert created.content == 'foobar'
+ assert created.content == "foobar"
# Test for particularly ugly regression with m2m in browsable API
@@ -427,7 +422,7 @@ class ClassASerializer(serializers.ModelSerializer):
class Meta:
model = ClassA
- fields = '__all__'
+ fields = "__all__"
class ExampleView(generics.ListCreateAPIView):
@@ -440,7 +435,7 @@ class TestM2MBrowsableAPI(TestCase):
"""
Test for particularly ugly regression with m2m in browsable API
"""
- request = factory.get('/', HTTP_ACCEPT='text/html')
+ request = factory.get("/", HTTP_ACCEPT="text/html")
view = ExampleView().as_view()
response = view(request).render()
assert response.status_code == status.HTTP_200_OK
@@ -448,12 +443,12 @@ class TestM2MBrowsableAPI(TestCase):
class InclusiveFilterBackend(object):
def filter_queryset(self, request, queryset, view):
- return queryset.filter(text='foo')
+ return queryset.filter(text="foo")
class ExclusiveFilterBackend(object):
def filter_queryset(self, request, queryset, view):
- return queryset.filter(text='other')
+ return queryset.filter(text="other")
class TwoFieldModel(models.Model):
@@ -466,16 +461,20 @@ class DynamicSerializerView(generics.ListCreateAPIView):
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get_serializer_class(self):
- if self.request.method == 'POST':
+ if self.request.method == "POST":
+
class DynamicSerializer(serializers.ModelSerializer):
class Meta:
model = TwoFieldModel
- fields = ('field_b',)
+ fields = ("field_b",)
+
else:
+
class DynamicSerializer(serializers.ModelSerializer):
class Meta:
model = TwoFieldModel
- fields = '__all__'
+ fields = "__all__"
+
return DynamicSerializer
@@ -484,32 +483,29 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
Create 3 BasicModel instances to filter on.
"""
- items = ['foo', 'bar', 'baz']
+ items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
- self.data = [
- {'id': obj.id, 'text': obj.text}
- for obj in self.objects.all()
- ]
+ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/')
+ request = factory.get("/")
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
- assert response.data == [{'id': 1, 'text': 'foo'}]
+ assert response.data == [{"id": 1, "text": "foo"}]
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/')
+ request = factory.get("/")
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == []
@@ -519,31 +515,33 @@ class TestFilterBackendAppliedToViews(TestCase):
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
- request = factory.get('/1')
+ request = factory.get("/1")
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
- assert response.data == {'detail': 'Not found.'}
+ assert response.data == {"detail": "Not found."}
- def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
+ def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(
+ self
+ ):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
- request = factory.get('/1')
+ request = factory.get("/1")
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'id': 1, 'text': 'foo'}
+ assert response.data == {"id": 1, "text": "foo"}
def test_dynamic_serializer_form_in_browsable_api(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
view = DynamicSerializerView.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request).render()
- content = response.content.decode('utf8')
- assert 'field_b' in content
- assert 'field_a' not in content
+ content = response.content.decode("utf8")
+ assert "field_b" in content
+ assert "field_a" not in content
class TestGuardedQueryset(TestCase):
@@ -555,21 +553,21 @@ class TestGuardedQueryset(TestCase):
return Response(list(self.queryset))
view = QuerysetAccessError.as_view()
- request = factory.get('/')
+ request = factory.get("/")
with pytest.raises(RuntimeError):
view(request).render()
class ApiViewsTests(TestCase):
-
def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView):
def create(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockCreateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.post('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.post("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -578,9 +576,10 @@ class ApiViewsTests(TestCase):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockDestroyApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.delete('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.delete("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -589,9 +588,10 @@ class ApiViewsTests(TestCase):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.patch('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.patch("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -600,9 +600,10 @@ class ApiViewsTests(TestCase):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockRetrieveUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.get('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.get("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -611,9 +612,10 @@ class ApiViewsTests(TestCase):
def update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockRetrieveUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.put('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.put("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -622,9 +624,10 @@ class ApiViewsTests(TestCase):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockRetrieveUpdateApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.patch('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.patch("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -633,9 +636,10 @@ class ApiViewsTests(TestCase):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockRetrieveDestroyUApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.get('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.get("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -644,9 +648,10 @@ class ApiViewsTests(TestCase):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
+
view = MockRetrieveDestroyUApiView()
- data = ('test request', ('test arg',), {'test_kwarg': 'test'})
- view.delete('test request', 'test arg', test_kwarg='test')
+ data = ("test request", ("test arg",), {"test_kwarg": "test"})
+ view.delete("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@@ -654,14 +659,12 @@ class ApiViewsTests(TestCase):
class GetObjectOr404Tests(TestCase):
def setUp(self):
super(GetObjectOr404Tests, self).setUp()
- self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
+ self.uuid_object = UUIDForeignKeyTarget.objects.create(name="bar")
def test_get_object_or_404_with_valid_uuid(self):
- obj = generics.get_object_or_404(
- UUIDForeignKeyTarget, pk=self.uuid_object.pk
- )
+ obj = generics.get_object_or_404(UUIDForeignKeyTarget, pk=self.uuid_object.pk)
assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self):
with pytest.raises(Http404):
- generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid')
+ generics.get_object_or_404(UUIDForeignKeyTarget, pk="not-a-uuid")
diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py
index decd25a3f..25925e53e 100644
--- a/tests/test_htmlrenderer.py
+++ b/tests/test_htmlrenderer.py
@@ -15,40 +15,41 @@ from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
-@api_view(('GET',))
+@api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,))
def example(request):
"""
A view that can returns an HTML representation.
"""
- data = {'object': 'foobar'}
- return Response(data, template_name='example.html')
+ data = {"object": "foobar"}
+ return Response(data, template_name="example.html")
-@api_view(('GET',))
+@api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,))
def permission_denied(request):
raise PermissionDenied()
-@api_view(('GET',))
+@api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,))
def not_found(request):
raise Http404()
urlpatterns = [
- url(r'^$', example),
- url(r'^permission_denied$', permission_denied),
- url(r'^not_found$', not_found),
+ url(r"^$", example),
+ url(r"^permission_denied$", permission_denied),
+ url(r"^not_found$", not_found),
]
-@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
+@override_settings(ROOT_URLCONF="tests.test_htmlrenderer")
class TemplateHTMLRendererTests(TestCase):
def setUp(self):
class MockResponse(object):
template_name = None
+
self.mock_response = MockResponse()
self._monkey_patch_get_template()
@@ -59,13 +60,13 @@ class TemplateHTMLRendererTests(TestCase):
self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None):
- if template_name == 'example.html':
- return engines['django'].from_string("example: {{ object }}")
+ if template_name == "example.html":
+ return engines["django"].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name)
def select_template(template_name_list, dirs=None, using=None):
- if template_name_list == ['example.html']:
- return engines['django'].from_string("example: {{ object }}")
+ if template_name_list == ["example.html"]:
+ return engines["django"].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template
@@ -78,29 +79,29 @@ class TemplateHTMLRendererTests(TestCase):
django.template.loader.get_template = self.get_template
def test_simple_html_view(self):
- response = self.client.get('/')
+ response = self.client.get("/")
self.assertContains(response, "example: foobar")
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_not_found_html_view(self):
- response = self.client.get('/not_found')
+ response = self.client.get("/not_found")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404 Not Found"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_permission_denied_html_view(self):
- response = self.client.get('/permission_denied')
+ response = self.client.get("/permission_denied")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403 Forbidden"))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
# 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer
def test_get_template_names_returns_own_template_name(self):
renderer = TemplateHTMLRenderer()
- renderer.template_name = 'test_template'
+ renderer.template_name = "test_template"
template_name = renderer.get_template_names(self.mock_response, view={})
- assert template_name == ['test_template']
+ assert template_name == ["test_template"]
def test_get_template_names_returns_view_template_name(self):
renderer = TemplateHTMLRenderer()
@@ -110,18 +111,16 @@ class TemplateHTMLRendererTests(TestCase):
class MockView(object):
def get_template_names(self):
- return ['template from get_template_names method']
+ return ["template from get_template_names method"]
class MockView2(object):
- template_name = 'template from template_name attribute'
+ template_name = "template from template_name attribute"
- template_name = renderer.get_template_names(self.mock_response,
- MockView())
- assert template_name == ['template from get_template_names method']
+ template_name = renderer.get_template_names(self.mock_response, MockView())
+ assert template_name == ["template from get_template_names method"]
- template_name = renderer.get_template_names(self.mock_response,
- MockView2())
- assert template_name == ['template from template_name attribute']
+ template_name = renderer.get_template_names(self.mock_response, MockView2())
+ assert template_name == ["template from template_name attribute"]
def test_get_template_names_raises_error_if_no_template_found(self):
renderer = TemplateHTMLRenderer()
@@ -129,7 +128,7 @@ class TemplateHTMLRendererTests(TestCase):
renderer.get_template_names(self.mock_response, view=object())
-@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
+@override_settings(ROOT_URLCONF="tests.test_htmlrenderer")
class TemplateHTMLRendererExceptionTests(TestCase):
def setUp(self):
"""
@@ -138,10 +137,10 @@ class TemplateHTMLRendererExceptionTests(TestCase):
self.get_template = django.template.loader.get_template
def get_template(template_name):
- if template_name == '404.html':
- return engines['django'].from_string("404: {{ detail }}")
- if template_name == '403.html':
- return engines['django'].from_string("403: {{ detail }}")
+ if template_name == "404.html":
+ return engines["django"].from_string("404: {{ detail }}")
+ if template_name == "403.html":
+ return engines["django"].from_string("403: {{ detail }}")
raise TemplateDoesNotExist(template_name)
django.template.loader.get_template = get_template
@@ -153,15 +152,18 @@ class TemplateHTMLRendererExceptionTests(TestCase):
django.template.loader.get_template = self.get_template
def test_not_found_html_view_with_template(self):
- response = self.client.get('/not_found')
+ response = self.client.get("/not_found")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
- self.assertTrue(response.content in (
- six.b("404: Not found"), six.b("404 Not Found")))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ self.assertTrue(
+ response.content in (six.b("404: Not found"), six.b("404 Not Found"))
+ )
+ self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_permission_denied_html_view_with_template(self):
- response = self.client.get('/permission_denied')
+ response = self.client.get("/permission_denied")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertTrue(response.content in (
- six.b("403: Permission denied"), six.b("403 Forbidden")))
- self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
+ self.assertTrue(
+ response.content
+ in (six.b("403: Permission denied"), six.b("403 Forbidden"))
+ )
+ self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
diff --git a/tests/test_lazy_hyperlinks.py b/tests/test_lazy_hyperlinks.py
index cf3ee735f..c46ee5efa 100644
--- a/tests/test_lazy_hyperlinks.py
+++ b/tests/test_lazy_hyperlinks.py
@@ -6,6 +6,7 @@ from rest_framework import serializers
from rest_framework.renderers import JSONRenderer
from rest_framework.templatetags.rest_framework import format_value
+
str_called = False
@@ -15,35 +16,33 @@ class Example(models.Model):
def __str__(self):
global str_called
str_called = True
- return 'An example'
+ return "An example"
class ExampleSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = Example
- fields = ('url', 'id', 'text')
+ fields = ("url", "id", "text")
def dummy_view(request):
pass
-urlpatterns = [
- url(r'^example/(?P[0-9]+)/$', dummy_view, name='example-detail'),
-]
+urlpatterns = [url(r"^example/(?P[0-9]+)/$", dummy_view, name="example-detail")]
-@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks')
+@override_settings(ROOT_URLCONF="tests.test_lazy_hyperlinks")
class TestLazyHyperlinkNames(TestCase):
def setUp(self):
- self.example = Example.objects.create(text='foo')
+ self.example = Example.objects.create(text="foo")
def test_lazy_hyperlink_names(self):
global str_called
- context = {'request': None}
+ context = {"request": None}
serializer = ExampleSerializer(self.example, context=context)
JSONRenderer().render(serializer.data)
assert not str_called
- hyperlink_string = format_value(serializer.data['url'])
- assert hyperlink_string == 'An example'
+ hyperlink_string = format_value(serializer.data["url"])
+ assert hyperlink_string == "An example"
assert str_called
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index fe4ea4b42..7e77d11d5 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -5,19 +5,17 @@ from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models
from django.test import TestCase
-from rest_framework import (
- exceptions, metadata, serializers, status, versioning, views
-)
+from rest_framework import exceptions, metadata, serializers, status, versioning, views
from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.test import APIRequestFactory
from .models import BasicModel
-request = APIRequestFactory().options('/')
+
+request = APIRequestFactory().options("/")
class TestMetadata:
-
def test_determine_metadata_abstract_method_raises_proper_error(self):
with pytest.raises(NotImplementedError):
metadata.BaseMetadata().determine_metadata(None, None)
@@ -26,24 +24,23 @@ class TestMetadata:
"""
OPTIONS requests to views should return a valid 200 response.
"""
+
class ExampleView(views.APIView):
"""Example view."""
+
pass
view = ExampleView.as_view()
response = view(request=request)
expected = {
- 'name': 'Example',
- 'description': 'Example view.',
- 'renders': [
- 'application/json',
- 'text/html'
+ "name": "Example",
+ "description": "Example view.",
+ "renders": ["application/json", "text/html"],
+ "parses": [
+ "application/json",
+ "application/x-www-form-urlencoded",
+ "multipart/form-data",
],
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ]
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
@@ -53,41 +50,40 @@ class TestMetadata:
OPTIONS requests to views where `metadata_class = None` should raise
a MethodNotAllowed exception, which will result in an HTTP 405 response.
"""
+
class ExampleView(views.APIView):
metadata_class = None
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
- assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}
+ assert response.data == {"detail": 'Method "OPTIONS" not allowed.'}
def test_actions(self):
"""
On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests.
"""
+
class NestedField(serializers.Serializer):
a = serializers.IntegerField()
b = serializers.IntegerField()
class ExampleSerializer(serializers.Serializer):
- choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
- integer_field = serializers.IntegerField(
- min_value=1, max_value=1000
- )
+ choice_field = serializers.ChoiceField(["red", "green", "blue"])
+ integer_field = serializers.IntegerField(min_value=1, max_value=1000)
char_field = serializers.CharField(
required=False, min_length=3, max_length=40
)
list_field = serializers.ListField(
- child=serializers.ListField(
- child=serializers.IntegerField()
- )
+ child=serializers.ListField(child=serializers.IntegerField())
)
nested_field = NestedField()
uuid_field = serializers.UUIDField(label="UUID field")
class ExampleView(views.APIView):
"""Example view."""
+
def post(self, request):
pass
@@ -97,91 +93,87 @@ class TestMetadata:
view = ExampleView.as_view()
response = view(request=request)
expected = {
- 'name': 'Example',
- 'description': 'Example view.',
- 'renders': [
- 'application/json',
- 'text/html'
+ "name": "Example",
+ "description": "Example view.",
+ "renders": ["application/json", "text/html"],
+ "parses": [
+ "application/json",
+ "application/x-www-form-urlencoded",
+ "multipart/form-data",
],
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'actions': {
- 'POST': {
- 'choice_field': {
- 'type': 'choice',
- 'required': True,
- 'read_only': False,
- 'label': 'Choice field',
- 'choices': [
- {'display_name': 'red', 'value': 'red'},
- {'display_name': 'green', 'value': 'green'},
- {'display_name': 'blue', 'value': 'blue'}
- ]
+ "actions": {
+ "POST": {
+ "choice_field": {
+ "type": "choice",
+ "required": True,
+ "read_only": False,
+ "label": "Choice field",
+ "choices": [
+ {"display_name": "red", "value": "red"},
+ {"display_name": "green", "value": "green"},
+ {"display_name": "blue", "value": "blue"},
+ ],
},
- 'integer_field': {
- 'type': 'integer',
- 'required': True,
- 'read_only': False,
- 'label': 'Integer field',
- 'min_value': 1,
- 'max_value': 1000,
-
+ "integer_field": {
+ "type": "integer",
+ "required": True,
+ "read_only": False,
+ "label": "Integer field",
+ "min_value": 1,
+ "max_value": 1000,
},
- 'char_field': {
- 'type': 'string',
- 'required': False,
- 'read_only': False,
- 'label': 'Char field',
- 'min_length': 3,
- 'max_length': 40
+ "char_field": {
+ "type": "string",
+ "required": False,
+ "read_only": False,
+ "label": "Char field",
+ "min_length": 3,
+ "max_length": 40,
},
- 'list_field': {
- 'type': 'list',
- 'required': True,
- 'read_only': False,
- 'label': 'List field',
- 'child': {
- 'type': 'list',
- 'required': True,
- 'read_only': False,
- 'child': {
- 'type': 'integer',
- 'required': True,
- 'read_only': False
- }
- }
- },
- 'nested_field': {
- 'type': 'nested object',
- 'required': True,
- 'read_only': False,
- 'label': 'Nested field',
- 'children': {
- 'a': {
- 'type': 'integer',
- 'required': True,
- 'read_only': False,
- 'label': 'A'
+ "list_field": {
+ "type": "list",
+ "required": True,
+ "read_only": False,
+ "label": "List field",
+ "child": {
+ "type": "list",
+ "required": True,
+ "read_only": False,
+ "child": {
+ "type": "integer",
+ "required": True,
+ "read_only": False,
},
- 'b': {
- 'type': 'integer',
- 'required': True,
- 'read_only': False,
- 'label': 'B'
- }
- }
+ },
},
- 'uuid_field': {
+ "nested_field": {
+ "type": "nested object",
+ "required": True,
+ "read_only": False,
+ "label": "Nested field",
+ "children": {
+ "a": {
+ "type": "integer",
+ "required": True,
+ "read_only": False,
+ "label": "A",
+ },
+ "b": {
+ "type": "integer",
+ "required": True,
+ "read_only": False,
+ "label": "B",
+ },
+ },
+ },
+ "uuid_field": {
"type": "string",
"required": True,
"read_only": False,
"label": "UUID field",
},
}
- }
+ },
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
@@ -191,13 +183,15 @@ class TestMetadata:
If a user does not have global permissions on an action, then any
metadata associated with it should not be included in OPTION responses.
"""
+
class ExampleSerializer(serializers.Serializer):
- choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
+ choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
class ExampleView(views.APIView):
"""Example view."""
+
def post(self, request):
pass
@@ -208,26 +202,28 @@ class TestMetadata:
return ExampleSerializer()
def check_permissions(self, request):
- if request.method == 'POST':
+ if request.method == "POST":
raise exceptions.PermissionDenied()
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_200_OK
- assert list(response.data['actions']) == ['PUT']
+ assert list(response.data["actions"]) == ["PUT"]
def test_object_permissions(self):
"""
If a user does not have object permissions on an action, then any
metadata associated with it should not be included in OPTION responses.
"""
+
class ExampleSerializer(serializers.Serializer):
- choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
+ choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
class ExampleView(views.APIView):
"""Example view."""
+
def post(self, request):
pass
@@ -238,13 +234,13 @@ class TestMetadata:
return ExampleSerializer()
def get_object(self):
- if self.request.method == 'PUT':
+ if self.request.method == "PUT":
raise exceptions.PermissionDenied()
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_200_OK
- assert list(response.data['actions'].keys()) == ['POST']
+ assert list(response.data["actions"].keys()) == ["POST"]
def test_bug_2455_clone_request(self):
class ExampleView(views.APIView):
@@ -254,7 +250,7 @@ class TestMetadata:
pass
def get_serializer(self):
- assert hasattr(self.request, 'version')
+ assert hasattr(self.request, "version")
return serializers.Serializer()
view = ExampleView.as_view()
@@ -268,7 +264,7 @@ class TestMetadata:
pass
def get_serializer(self):
- assert hasattr(self.request, 'versioning_scheme')
+ assert hasattr(self.request, "versioning_scheme")
return serializers.Serializer()
scheme = versioning.QueryParameterVersioning
@@ -279,12 +275,14 @@ class TestMetadata:
"""
HiddenField shouldn't show up in SimpleMetadata at all.
"""
+
class ExampleSerializer(serializers.Serializer):
integer_field = serializers.IntegerField(max_value=10)
hidden_field = serializers.HiddenField(default=1)
class ExampleView(views.APIView):
"""Example view."""
+
def post(self, request):
pass
@@ -294,9 +292,11 @@ class TestMetadata:
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_200_OK
- assert set(response.data['actions']['POST'].keys()) == {'integer_field'}
+ assert set(response.data["actions"]["POST"].keys()) == {"integer_field"}
- def test_list_serializer_metadata_returns_info_about_fields_of_child_serializer(self):
+ def test_list_serializer_metadata_returns_info_about_fields_of_child_serializer(
+ self
+ ):
class ExampleSerializer(serializers.Serializer):
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
@@ -307,14 +307,16 @@ class TestMetadata:
options = metadata.SimpleMetadata()
child_serializer = ExampleSerializer()
list_serializer = ExampleListSerializer(child=child_serializer)
- assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer)
+ assert options.get_serializer_info(
+ list_serializer
+ ) == options.get_serializer_info(child_serializer)
class TestSimpleMetadataFieldInfo(TestCase):
def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField())
- assert field_info['type'] == 'boolean'
+ assert field_info["type"] == "boolean"
def test_related_field_choices(self):
options = metadata.SimpleMetadata()
@@ -323,7 +325,7 @@ class TestSimpleMetadataFieldInfo(TestCase):
field_info = options.get_field_info(
serializers.RelatedField(queryset=BasicModel.objects.all())
)
- assert 'choices' not in field_info
+ assert "choices" not in field_info
class TestModelSerializerMetadata(TestCase):
@@ -333,9 +335,12 @@ class TestModelSerializerMetadata(TestCase):
on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present
"""
+
class Parent(models.Model):
- integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
- children = models.ManyToManyField('Child')
+ integer_field = models.IntegerField(
+ validators=[MinValueValidator(1), MaxValueValidator(1000)]
+ )
+ children = models.ManyToManyField("Child")
name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model):
@@ -346,10 +351,11 @@ class TestModelSerializerMetadata(TestCase):
class Meta:
model = Parent
- fields = '__all__'
+ fields = "__all__"
class ExampleView(views.APIView):
"""Example view."""
+
def post(self, request):
pass
@@ -359,48 +365,45 @@ class TestModelSerializerMetadata(TestCase):
view = ExampleView.as_view()
response = view(request=request)
expected = {
- 'name': 'Example',
- 'description': 'Example view.',
- 'renders': [
- 'application/json',
- 'text/html'
+ "name": "Example",
+ "description": "Example view.",
+ "renders": ["application/json", "text/html"],
+ "parses": [
+ "application/json",
+ "application/x-www-form-urlencoded",
+ "multipart/form-data",
],
- 'parses': [
- 'application/json',
- 'application/x-www-form-urlencoded',
- 'multipart/form-data'
- ],
- 'actions': {
- 'POST': {
- 'id': {
- 'type': 'integer',
- 'required': False,
- 'read_only': True,
- 'label': 'ID'
+ "actions": {
+ "POST": {
+ "id": {
+ "type": "integer",
+ "required": False,
+ "read_only": True,
+ "label": "ID",
},
- 'children': {
- 'type': 'field',
- 'required': False,
- 'read_only': True,
- 'label': 'Children'
+ "children": {
+ "type": "field",
+ "required": False,
+ "read_only": True,
+ "label": "Children",
},
- 'integer_field': {
- 'type': 'integer',
- 'required': True,
- 'read_only': False,
- 'label': 'Integer field',
- 'min_value': 1,
- 'max_value': 1000
+ "integer_field": {
+ "type": "integer",
+ "required": True,
+ "read_only": False,
+ "label": "Integer field",
+ "min_value": 1,
+ "max_value": 1000,
+ },
+ "name": {
+ "type": "string",
+ "required": False,
+ "read_only": False,
+ "label": "Name",
+ "max_length": 100,
},
- 'name': {
- 'type': 'string',
- 'required': False,
- 'read_only': False,
- 'label': 'Name',
- 'max_length': 100
- }
}
- }
+ },
}
assert response.status_code == status.HTTP_200_OK
diff --git a/tests/test_middleware.py b/tests/test_middleware.py
index 9df7d8e3e..66cb8ebc2 100644
--- a/tests/test_middleware.py
+++ b/tests/test_middleware.py
@@ -17,8 +17,8 @@ class PostView(APIView):
urlpatterns = [
- url(r'^auth$', APIView.as_view(authentication_classes=(TokenAuthentication,))),
- url(r'^post$', PostView.as_view()),
+ url(r"^auth$", APIView.as_view(authentication_classes=(TokenAuthentication,))),
+ url(r"^post$", PostView.as_view()),
]
@@ -28,8 +28,8 @@ class RequestUserMiddleware(object):
def __call__(self, request):
response = self.get_response(request)
- assert hasattr(request, 'user'), '`user` is not set on request'
- assert request.user.is_authenticated, '`user` is not authenticated'
+ assert hasattr(request, "user"), "`user` is not set on request"
+ assert request.user.is_authenticated, "`user` is not authenticated"
return response
@@ -49,28 +49,27 @@ class RequestPOSTMiddleware(object):
# Ensure request.POST is set as appropriate
if is_form_media_type(request.content_type):
- assert request.POST == {'foo': ['bar']}
+ assert request.POST == {"foo": ["bar"]}
else:
assert request.POST == {}
return response
-@override_settings(ROOT_URLCONF='tests.test_middleware')
+@override_settings(ROOT_URLCONF="tests.test_middleware")
class TestMiddleware(APITestCase):
-
- @override_settings(MIDDLEWARE=('tests.test_middleware.RequestUserMiddleware',))
+ @override_settings(MIDDLEWARE=("tests.test_middleware.RequestUserMiddleware",))
def test_middleware_can_access_user_when_processing_response(self):
- user = User.objects.create_user('john', 'john@example.com', 'password')
- key = 'abcd1234'
+ user = User.objects.create_user("john", "john@example.com", "password")
+ key = "abcd1234"
Token.objects.create(key=key, user=user)
- self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key)
+ self.client.get("/auth", HTTP_AUTHORIZATION="Token %s" % key)
- @override_settings(MIDDLEWARE=('tests.test_middleware.RequestPOSTMiddleware',))
+ @override_settings(MIDDLEWARE=("tests.test_middleware.RequestPOSTMiddleware",))
def test_middleware_can_access_request_post_when_processing_response(self):
- response = self.client.post('/post', {'foo': 'bar'})
+ response = self.client.post("/post", {"foo": "bar"})
assert response.status_code == 200
- response = self.client.post('/post', {'foo': 'bar'}, format='json')
+ response = self.client.post("/post", {"foo": "bar"}, format="json")
assert response.status_code == 200
diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py
index 898c859a4..91722b3f7 100644
--- a/tests/test_model_serializer.py
+++ b/tests/test_model_serializer.py
@@ -16,7 +16,9 @@ import django
import pytest
from django.core.exceptions import ImproperlyConfigured
from django.core.validators import (
- MaxValueValidator, MinLengthValidator, MinValueValidator
+ MaxValueValidator,
+ MinLengthValidator,
+ MinValueValidator,
)
from django.db import models
from django.test import TestCase
@@ -29,16 +31,18 @@ from .models import NestedForeignKeySource
def dedent(blocktext):
- return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
+ return "\n".join([line[12:] for line in blocktext.splitlines()[1:-1]])
# Tests for regular field mappings.
# ---------------------------------
+
class CustomField(models.Field):
"""
A custom model field simply for testing purposes.
"""
+
pass
@@ -50,6 +54,7 @@ class RegularFieldsModel(models.Model):
"""
A model class for testing regular flat fields.
"""
+
auto_field = models.AutoField(primary_key=True)
big_integer_field = models.BigIntegerField()
boolean_field = models.BooleanField(default=False)
@@ -71,28 +76,40 @@ class RegularFieldsModel(models.Model):
time_field = models.TimeField()
url_field = models.URLField(max_length=100)
custom_field = CustomField()
- file_path_field = models.FilePathField(path='/tmp/')
+ file_path_field = models.FilePathField(path="/tmp/")
def method(self):
- return 'method'
+ return "method"
-COLOR_CHOICES = (('red', 'Red'), ('blue', 'Blue'), ('green', 'Green'))
-DECIMAL_CHOICES = (('low', decimal.Decimal('0.1')), ('medium', decimal.Decimal('0.5')), ('high', decimal.Decimal('0.9')))
+COLOR_CHOICES = (("red", "Red"), ("blue", "Blue"), ("green", "Green"))
+DECIMAL_CHOICES = (
+ ("low", decimal.Decimal("0.1")),
+ ("medium", decimal.Decimal("0.5")),
+ ("high", decimal.Decimal("0.9")),
+)
class FieldOptionsModel(models.Model):
- value_limit_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(10)])
- length_limit_field = models.CharField(validators=[MinLengthValidator(3)], max_length=12)
+ value_limit_field = models.IntegerField(
+ validators=[MinValueValidator(1), MaxValueValidator(10)]
+ )
+ length_limit_field = models.CharField(
+ validators=[MinLengthValidator(3)], max_length=12
+ )
blank_field = models.CharField(blank=True, max_length=10)
null_field = models.IntegerField(null=True)
default_field = models.IntegerField(default=0)
- descriptive_field = models.IntegerField(help_text='Some help text', verbose_name='A label')
+ descriptive_field = models.IntegerField(
+ help_text="Some help text", verbose_name="A label"
+ )
choices_field = models.CharField(max_length=100, choices=COLOR_CHOICES)
class ChoicesModel(models.Model):
- choices_field_with_nonstandard_args = models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES, verbose_name='A label')
+ choices_field_with_nonstandard_args = models.DecimalField(
+ max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES, verbose_name="A label"
+ )
class Issue3674ParentModel(models.Model):
@@ -100,15 +117,14 @@ class Issue3674ParentModel(models.Model):
class Issue3674ChildModel(models.Model):
- parent = models.ForeignKey(Issue3674ParentModel, related_name='children', on_delete=models.CASCADE)
+ parent = models.ForeignKey(
+ Issue3674ParentModel, related_name="children", on_delete=models.CASCADE
+ )
value = models.CharField(primary_key=True, max_length=64)
class UniqueChoiceModel(models.Model):
- CHOICES = (
- ('choice1', 'choice 1'),
- ('choice2', 'choice 1'),
- )
+ CHOICES = (("choice1", "choice 1"), ("choice2", "choice 1"))
name = models.CharField(max_length=254, unique=True, choices=CHOICES)
@@ -120,15 +136,14 @@ class TestModelSerializer(TestCase):
class Meta:
model = OneFieldModel
- fields = ('char_field', 'non_model_field')
+ fields = ("char_field", "non_model_field")
- serializer = TestSerializer(data={
- 'char_field': 'foo',
- 'non_model_field': 'bar',
- })
+ serializer = TestSerializer(
+ data={"char_field": "foo", "non_model_field": "bar"}
+ )
serializer.is_valid()
- msginitial = 'Got a `TypeError` when calling `OneFieldModel.objects.create()`.'
+ msginitial = "Got a `TypeError` when calling `OneFieldModel.objects.create()`."
with self.assertRaisesMessage(TypeError, msginitial):
serializer.save()
@@ -137,6 +152,7 @@ class TestModelSerializer(TestCase):
Test that trying to use ModelSerializer with Abstract Models
throws a ValueError exception.
"""
+
class AbstractModel(models.Model):
afield = models.CharField(max_length=255)
@@ -146,13 +162,11 @@ class TestModelSerializer(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = AbstractModel
- fields = ('afield',)
+ fields = ("afield",)
- serializer = TestSerializer(data={
- 'afield': 'foo',
- })
+ serializer = TestSerializer(data={"afield": "foo"})
- msginitial = 'Cannot use ModelSerializer with Abstract Models.'
+ msginitial = "Cannot use ModelSerializer with Abstract Models."
with self.assertRaisesMessage(ValueError, msginitial):
serializer.is_valid()
@@ -162,12 +176,14 @@ class TestRegularFieldMappings(TestCase):
"""
Model fields should map to their equivalent serializer fields.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
auto_field = IntegerField(read_only=True)
big_integer_field = IntegerField()
@@ -191,7 +207,8 @@ class TestRegularFieldMappings(TestCase):
url_field = URLField(max_length=100)
custom_field = ModelField(model_field=)
file_path_field = FilePathField(path='/tmp/')
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
@@ -199,9 +216,10 @@ class TestRegularFieldMappings(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = FieldOptionsModel
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
value_limit_field = IntegerField(max_value=10, min_value=1)
@@ -211,19 +229,20 @@ class TestRegularFieldMappings(TestCase):
default_field = IntegerField(required=False)
descriptive_field = IntegerField(help_text='Some help text', label='A label')
choices_field = ChoiceField(choices=(('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')))
- """)
+ """
+ )
if six.PY2:
# This particular case is too awkward to resolve fully across
# both py2 and py3.
expected = expected.replace(
"('red', 'Red'), ('blue', 'Blue'), ('green', 'Green')",
- "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')"
+ "(u'red', u'Red'), (u'blue', u'Blue'), (u'green', u'Green')",
)
self.assertEqual(unicode_repr(TestSerializer()), expected)
# merge this into test_regular_fields / RegularFieldsModel when
# Django 2.1 is the minimum supported version
- @pytest.mark.skipif(django.VERSION < (2, 1), reason='Django version < 2.1')
+ @pytest.mark.skipif(django.VERSION < (2, 1), reason="Django version < 2.1")
def test_nullable_boolean_field(self):
class NullableBooleanModel(models.Model):
field = models.BooleanField(null=True, default=False)
@@ -231,12 +250,14 @@ class TestRegularFieldMappings(TestCase):
class NullableBooleanSerializer(serializers.ModelSerializer):
class Meta:
model = NullableBooleanModel
- fields = ['field']
+ fields = ["field"]
- expected = dedent("""
+ expected = dedent(
+ """
NullableBooleanSerializer():
field = BooleanField(allow_null=True, required=False)
- """)
+ """
+ )
self.assertEqual(unicode_repr(NullableBooleanSerializer()), expected)
@@ -245,66 +266,78 @@ class TestRegularFieldMappings(TestCase):
Properties and methods on the model should be allowed as `Meta.fields`
values, and should map to `ReadOnlyField`.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
- fields = ('auto_field', 'method')
+ fields = ("auto_field", "method")
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
auto_field = IntegerField(read_only=True)
method = ReadOnlyField()
- """)
+ """
+ )
self.assertEqual(repr(TestSerializer()), expected)
def test_pk_fields(self):
"""
Both `pk` and the actual primary key name are valid in `Meta.fields`.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
- fields = ('pk', 'auto_field')
+ fields = ("pk", "auto_field")
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
pk = IntegerField(label='Auto field', read_only=True)
auto_field = IntegerField(read_only=True)
- """)
+ """
+ )
self.assertEqual(repr(TestSerializer()), expected)
def test_extra_field_kwargs(self):
"""
Ensure `extra_kwargs` are passed to generated fields.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
- fields = ('auto_field', 'char_field')
- extra_kwargs = {'char_field': {'default': 'extra'}}
+ fields = ("auto_field", "char_field")
+ extra_kwargs = {"char_field": {"default": "extra"}}
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
auto_field = IntegerField(read_only=True)
char_field = CharField(default='extra', max_length=100)
- """)
+ """
+ )
self.assertEqual(repr(TestSerializer()), expected)
def test_extra_field_kwargs_required(self):
"""
Ensure `extra_kwargs` are passed to generated fields.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
- fields = ('auto_field', 'char_field')
- extra_kwargs = {'auto_field': {'required': False, 'read_only': False}}
+ fields = ("auto_field", "char_field")
+ extra_kwargs = {"auto_field": {"required": False, "read_only": False}}
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
auto_field = IntegerField(read_only=False, required=False)
char_field = CharField(max_length=100)
- """)
+ """
+ )
self.assertEqual(repr(TestSerializer()), expected)
def test_invalid_field(self):
@@ -312,12 +345,13 @@ class TestRegularFieldMappings(TestCase):
Field names that do not map to a model field or relationship should
raise a configuration errror.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
- fields = ('auto_field', 'invalid')
+ fields = ("auto_field", "invalid")
- expected = 'Field name `invalid` is not valid for model `RegularFieldsModel`.'
+ expected = "Field name `invalid` is not valid for model `RegularFieldsModel`."
with self.assertRaisesMessage(ImproperlyConfigured, expected):
TestSerializer().fields
@@ -326,12 +360,13 @@ class TestRegularFieldMappings(TestCase):
Fields that have been declared on the serializer class must be included
in the `Meta.fields` if it exists.
"""
+
class TestSerializer(serializers.ModelSerializer):
missing = serializers.ReadOnlyField()
class Meta:
model = RegularFieldsModel
- fields = ('auto_field',)
+ fields = ("auto_field",)
expected = (
"The field 'missing' was declared on serializer TestSerializer, "
@@ -345,13 +380,14 @@ class TestRegularFieldMappings(TestCase):
Fields that have been declared on a parent of the serializer class may
be excluded from the `Meta.fields` option.
"""
+
class TestSerializer(serializers.ModelSerializer):
missing = serializers.ReadOnlyField()
class ChildSerializer(TestSerializer):
class Meta:
model = RegularFieldsModel
- fields = ('auto_field',)
+ fields = ("auto_field",)
ChildSerializer().fields
@@ -359,7 +395,7 @@ class TestRegularFieldMappings(TestCase):
class ExampleSerializer(serializers.ModelSerializer):
class Meta:
model = ChoicesModel
- fields = '__all__'
+ fields = "__all__"
ExampleSerializer()
@@ -370,18 +406,21 @@ class TestDurationFieldMapping(TestCase):
"""
A model that defines DurationField.
"""
+
duration_field = models.DurationField()
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = DurationFieldModel
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
duration_field = DurationField()
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_duration_field_with_validators(self):
@@ -389,24 +428,36 @@ class TestDurationFieldMapping(TestCase):
"""
A model that defines DurationField with validators.
"""
+
duration_field = models.DurationField(
- validators=[MinValueValidator(datetime.timedelta(days=1)), MaxValueValidator(datetime.timedelta(days=3))]
+ validators=[
+ MinValueValidator(datetime.timedelta(days=1)),
+ MaxValueValidator(datetime.timedelta(days=3)),
+ ]
)
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ValidatedDurationFieldModel
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = (
+ dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
duration_field = DurationField(max_value=datetime.timedelta(3), min_value=datetime.timedelta(1))
- """) if sys.version_info < (3, 7) else dedent("""
+ """
+ )
+ if sys.version_info < (3, 7)
+ else dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
duration_field = DurationField(max_value=datetime.timedelta(days=3), min_value=datetime.timedelta(days=1))
- """)
+ """
+ )
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
@@ -418,16 +469,18 @@ class TestGenericIPAddressFieldValidation(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = IPAddressFieldModel
- fields = '__all__'
+ fields = "__all__"
- s = TestSerializer(data={'address': 'not an ip address'})
+ s = TestSerializer(data={"address": "not an ip address"})
self.assertFalse(s.is_valid())
- self.assertEqual(1, len(s.errors['address']),
- 'Unexpected number of validation errors: '
- '{0}'.format(s.errors))
+ self.assertEqual(
+ 1,
+ len(s.errors["address"]),
+ "Unexpected number of validation errors: " "{0}".format(s.errors),
+ )
-@pytest.mark.skipif('not postgres_fields')
+@pytest.mark.skipif("not postgres_fields")
class TestPosgresFieldsMapping(TestCase):
def test_hstore_field(self):
class HStoreFieldModel(models.Model):
@@ -436,12 +489,14 @@ class TestPosgresFieldsMapping(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = HStoreFieldModel
- fields = ['hstore_field']
+ fields = ["hstore_field"]
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
hstore_field = HStoreField()
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_array_field(self):
@@ -451,12 +506,14 @@ class TestPosgresFieldsMapping(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ArrayFieldModel
- fields = ['array_field']
+ fields = ["array_field"]
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
array_field = ListField(child=CharField(label='Array field', validators=[]))
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_json_field(self):
@@ -466,18 +523,21 @@ class TestPosgresFieldsMapping(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = JSONFieldModel
- fields = ['json_field']
+ fields = ["json_field"]
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
json_field = JSONField(style={'base_template': 'textarea.html'})
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
# Tests for relational field mappings.
# ------------------------------------
+
class ForeignKeyTargetModel(models.Model):
name = models.CharField(max_length=100)
@@ -496,20 +556,36 @@ class ThroughTargetModel(models.Model):
class Supplementary(models.Model):
extra = models.IntegerField()
- forwards = models.ForeignKey('ThroughTargetModel', on_delete=models.CASCADE)
- backwards = models.ForeignKey('RelationalModel', on_delete=models.CASCADE)
+ forwards = models.ForeignKey("ThroughTargetModel", on_delete=models.CASCADE)
+ backwards = models.ForeignKey("RelationalModel", on_delete=models.CASCADE)
class RelationalModel(models.Model):
- foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='reverse_foreign_key', on_delete=models.CASCADE)
- many_to_many = models.ManyToManyField(ManyToManyTargetModel, related_name='reverse_many_to_many')
- one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='reverse_one_to_one', on_delete=models.CASCADE)
- through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through')
+ foreign_key = models.ForeignKey(
+ ForeignKeyTargetModel,
+ related_name="reverse_foreign_key",
+ on_delete=models.CASCADE,
+ )
+ many_to_many = models.ManyToManyField(
+ ManyToManyTargetModel, related_name="reverse_many_to_many"
+ )
+ one_to_one = models.OneToOneField(
+ OneToOneTargetModel, related_name="reverse_one_to_one", on_delete=models.CASCADE
+ )
+ through = models.ManyToManyField(
+ ThroughTargetModel, through=Supplementary, related_name="reverse_through"
+ )
class UniqueTogetherModel(models.Model):
- foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='unique_foreign_key', on_delete=models.CASCADE)
- one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='unique_one_to_one', on_delete=models.CASCADE)
+ foreign_key = models.ForeignKey(
+ ForeignKeyTargetModel,
+ related_name="unique_foreign_key",
+ on_delete=models.CASCADE,
+ )
+ one_to_one = models.OneToOneField(
+ OneToOneTargetModel, related_name="unique_one_to_one", on_delete=models.CASCADE
+ )
class Meta:
unique_together = ("foreign_key", "one_to_one")
@@ -520,16 +596,18 @@ class TestRelationalFieldMappings(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = PrimaryKeyRelatedField(queryset=ForeignKeyTargetModel.objects.all())
one_to_one = PrimaryKeyRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[])
many_to_many = PrimaryKeyRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all())
through = PrimaryKeyRelatedField(many=True, read_only=True)
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_nested_relations(self):
@@ -537,9 +615,10 @@ class TestRelationalFieldMappings(TestCase):
class Meta:
model = RelationalModel
depth = 1
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = NestedSerializer(read_only=True):
@@ -554,23 +633,26 @@ class TestRelationalFieldMappings(TestCase):
through = NestedSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = HyperlinkedRelatedField(queryset=ForeignKeyTargetModel.objects.all(), view_name='foreignkeytargetmodel-detail')
one_to_one = HyperlinkedRelatedField(queryset=OneToOneTargetModel.objects.all(), validators=[], view_name='onetoonetargetmodel-detail')
many_to_many = HyperlinkedRelatedField(allow_empty=False, many=True, queryset=ManyToManyTargetModel.objects.all(), view_name='manytomanytargetmodel-detail')
through = HyperlinkedRelatedField(many=True, read_only=True, view_name='throughtargetmodel-detail')
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_nested_hyperlinked_relations(self):
@@ -578,9 +660,10 @@ class TestRelationalFieldMappings(TestCase):
class Meta:
model = RelationalModel
depth = 1
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = NestedSerializer(read_only=True):
@@ -595,7 +678,8 @@ class TestRelationalFieldMappings(TestCase):
through = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_nested_hyperlinked_relations_starred_source(self):
@@ -603,14 +687,12 @@ class TestRelationalFieldMappings(TestCase):
class Meta:
model = RelationalModel
depth = 1
- fields = '__all__'
+ fields = "__all__"
- extra_kwargs = {
- 'url': {
- 'source': '*',
- }}
+ extra_kwargs = {"url": {"source": "*"}}
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
url = HyperlinkedIdentityField(source='*', view_name='relationalmodel-detail')
foreign_key = NestedSerializer(read_only=True):
@@ -625,7 +707,8 @@ class TestRelationalFieldMappings(TestCase):
through = NestedSerializer(many=True, read_only=True):
url = HyperlinkedIdentityField(view_name='throughtargetmodel-detail')
name = CharField(max_length=100)
- """)
+ """
+ )
self.maxDiff = None
self.assertEqual(unicode_repr(TestSerializer()), expected)
@@ -634,9 +717,10 @@ class TestRelationalFieldMappings(TestCase):
class Meta:
model = UniqueTogetherModel
depth = 1
- fields = '__all__'
+ fields = "__all__"
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
url = HyperlinkedIdentityField(view_name='uniquetogethermodel-detail')
foreign_key = NestedSerializer(read_only=True):
@@ -645,13 +729,13 @@ class TestRelationalFieldMappings(TestCase):
one_to_one = NestedSerializer(read_only=True):
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
name = CharField(max_length=100)
- """)
+ """
+ )
if six.PY2:
# This case is also too awkward to resolve fully across both py2
# and py3. (See above)
expected = expected.replace(
- "('foreign_key', 'one_to_one')",
- "(u'foreign_key', u'one_to_one')"
+ "('foreign_key', 'one_to_one')", "(u'foreign_key', u'one_to_one')"
)
self.assertEqual(unicode_repr(TestSerializer()), expected)
@@ -659,56 +743,64 @@ class TestRelationalFieldMappings(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeyTargetModel
- fields = ('id', 'name', 'reverse_foreign_key')
+ fields = ("id", "name", "reverse_foreign_key")
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_foreign_key = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_one_to_one(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = OneToOneTargetModel
- fields = ('id', 'name', 'reverse_one_to_one')
+ fields = ("id", "name", "reverse_one_to_one")
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_one_to_one = PrimaryKeyRelatedField(queryset=RelationalModel.objects.all())
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_many_to_many(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManyTargetModel
- fields = ('id', 'name', 'reverse_many_to_many')
+ fields = ("id", "name", "reverse_many_to_many")
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_many_to_many = PrimaryKeyRelatedField(many=True, queryset=RelationalModel.objects.all())
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
def test_pk_reverse_through(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ThroughTargetModel
- fields = ('id', 'name', 'reverse_through')
+ fields = ("id", "name", "reverse_through")
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
name = CharField(max_length=100)
reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestSerializer()), expected)
@@ -716,7 +808,7 @@ class DisplayValueTargetModel(models.Model):
name = models.CharField(max_length=100)
def __str__(self):
- return '%s Color' % (self.name)
+ return "%s Color" % (self.name)
class DisplayValueModel(models.Model):
@@ -725,55 +817,57 @@ class DisplayValueModel(models.Model):
class TestRelationalFieldDisplayValue(TestCase):
def setUp(self):
- DisplayValueTargetModel.objects.bulk_create([
- DisplayValueTargetModel(name='Red'),
- DisplayValueTargetModel(name='Yellow'),
- DisplayValueTargetModel(name='Green'),
- ])
+ DisplayValueTargetModel.objects.bulk_create(
+ [
+ DisplayValueTargetModel(name="Red"),
+ DisplayValueTargetModel(name="Yellow"),
+ DisplayValueTargetModel(name="Green"),
+ ]
+ )
def test_default_display_value(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = DisplayValueModel
- fields = '__all__'
+ fields = "__all__"
serializer = TestSerializer()
- expected = OrderedDict([(1, 'Red Color'), (2, 'Yellow Color'), (3, 'Green Color')])
- self.assertEqual(serializer.fields['color'].choices, expected)
+ expected = OrderedDict(
+ [(1, "Red Color"), (2, "Yellow Color"), (3, "Green Color")]
+ )
+ self.assertEqual(serializer.fields["color"].choices, expected)
def test_custom_display_value(self):
class TestField(serializers.PrimaryKeyRelatedField):
def display_value(self, instance):
- return 'My %s Color' % (instance.name)
+ return "My %s Color" % (instance.name)
class TestSerializer(serializers.ModelSerializer):
color = TestField(queryset=DisplayValueTargetModel.objects.all())
class Meta:
model = DisplayValueModel
- fields = '__all__'
+ fields = "__all__"
serializer = TestSerializer()
- expected = OrderedDict([(1, 'My Red Color'), (2, 'My Yellow Color'), (3, 'My Green Color')])
- self.assertEqual(serializer.fields['color'].choices, expected)
+ expected = OrderedDict(
+ [(1, "My Red Color"), (2, "My Yellow Color"), (3, "My Green Color")]
+ )
+ self.assertEqual(serializer.fields["color"].choices, expected)
class TestIntegration(TestCase):
def setUp(self):
self.foreign_key_target = ForeignKeyTargetModel.objects.create(
- name='foreign_key'
- )
- self.one_to_one_target = OneToOneTargetModel.objects.create(
- name='one_to_one'
+ name="foreign_key"
)
+ self.one_to_one_target = OneToOneTargetModel.objects.create(name="one_to_one")
self.many_to_many_targets = [
- ManyToManyTargetModel.objects.create(
- name='many_to_many (%d)' % idx
- ) for idx in range(3)
+ ManyToManyTargetModel.objects.create(name="many_to_many (%d)" % idx)
+ for idx in range(3)
]
self.instance = RelationalModel.objects.create(
- foreign_key=self.foreign_key_target,
- one_to_one=self.one_to_one_target,
+ foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target
)
self.instance.many_to_many.set(self.many_to_many_targets)
@@ -781,15 +875,15 @@ class TestIntegration(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
- fields = '__all__'
+ fields = "__all__"
serializer = TestSerializer(self.instance)
expected = {
- 'id': self.instance.pk,
- 'foreign_key': self.foreign_key_target.pk,
- 'one_to_one': self.one_to_one_target.pk,
- 'many_to_many': [item.pk for item in self.many_to_many_targets],
- 'through': []
+ "id": self.instance.pk,
+ "foreign_key": self.foreign_key_target.pk,
+ "one_to_one": self.one_to_one_target.pk,
+ "many_to_many": [item.pk for item in self.many_to_many_targets],
+ "through": [],
}
self.assertEqual(serializer.data, expected)
@@ -797,23 +891,18 @@ class TestIntegration(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
- fields = '__all__'
+ fields = "__all__"
- new_foreign_key = ForeignKeyTargetModel.objects.create(
- name='foreign_key'
- )
- new_one_to_one = OneToOneTargetModel.objects.create(
- name='one_to_one'
- )
+ new_foreign_key = ForeignKeyTargetModel.objects.create(name="foreign_key")
+ new_one_to_one = OneToOneTargetModel.objects.create(name="one_to_one")
new_many_to_many = [
- ManyToManyTargetModel.objects.create(
- name='new many_to_many (%d)' % idx
- ) for idx in range(3)
+ ManyToManyTargetModel.objects.create(name="new many_to_many (%d)" % idx)
+ for idx in range(3)
]
data = {
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
+ "foreign_key": new_foreign_key.pk,
+ "one_to_one": new_one_to_one.pk,
+ "many_to_many": [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
@@ -824,20 +913,18 @@ class TestIntegration(TestCase):
instance = serializer.save()
assert instance.foreign_key.pk == new_foreign_key.pk
assert instance.one_to_one.pk == new_one_to_one.pk
- assert [
- item.pk for item in instance.many_to_many.all()
- ] == [
+ assert [item.pk for item in instance.many_to_many.all()] == [
item.pk for item in new_many_to_many
]
assert list(instance.through.all()) == []
# Representation should be correct.
expected = {
- 'id': instance.pk,
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
- 'through': []
+ "id": instance.pk,
+ "foreign_key": new_foreign_key.pk,
+ "one_to_one": new_one_to_one.pk,
+ "many_to_many": [item.pk for item in new_many_to_many],
+ "through": [],
}
self.assertEqual(serializer.data, expected)
@@ -845,23 +932,18 @@ class TestIntegration(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
- fields = '__all__'
+ fields = "__all__"
- new_foreign_key = ForeignKeyTargetModel.objects.create(
- name='foreign_key'
- )
- new_one_to_one = OneToOneTargetModel.objects.create(
- name='one_to_one'
- )
+ new_foreign_key = ForeignKeyTargetModel.objects.create(name="foreign_key")
+ new_one_to_one = OneToOneTargetModel.objects.create(name="one_to_one")
new_many_to_many = [
- ManyToManyTargetModel.objects.create(
- name='new many_to_many (%d)' % idx
- ) for idx in range(3)
+ ManyToManyTargetModel.objects.create(name="new many_to_many (%d)" % idx)
+ for idx in range(3)
]
data = {
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
+ "foreign_key": new_foreign_key.pk,
+ "one_to_one": new_one_to_one.pk,
+ "many_to_many": [item.pk for item in new_many_to_many],
}
# Serializer should validate okay.
@@ -872,26 +954,25 @@ class TestIntegration(TestCase):
instance = serializer.save()
assert instance.foreign_key.pk == new_foreign_key.pk
assert instance.one_to_one.pk == new_one_to_one.pk
- assert [
- item.pk for item in instance.many_to_many.all()
- ] == [
+ assert [item.pk for item in instance.many_to_many.all()] == [
item.pk for item in new_many_to_many
]
assert list(instance.through.all()) == []
# Representation should be correct.
expected = {
- 'id': self.instance.pk,
- 'foreign_key': new_foreign_key.pk,
- 'one_to_one': new_one_to_one.pk,
- 'many_to_many': [item.pk for item in new_many_to_many],
- 'through': []
+ "id": self.instance.pk,
+ "foreign_key": new_foreign_key.pk,
+ "one_to_one": new_one_to_one.pk,
+ "many_to_many": [item.pk for item in new_many_to_many],
+ "through": [],
}
self.assertEqual(serializer.data, expected)
# Tests for bulk create using `ListSerializer`.
+
class BulkCreateModel(models.Model):
name = models.CharField(max_length=10)
@@ -901,23 +982,27 @@ class TestBulkCreate(TestCase):
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BulkCreateModel
- fields = ('name',)
+ fields = ("name",)
class BulkCreateSerializer(serializers.ListSerializer):
child = BasicModelSerializer()
- data = [{'name': 'a'}, {'name': 'b'}, {'name': 'c'}]
+ data = [{"name": "a"}, {"name": "b"}, {"name": "c"}]
serializer = BulkCreateSerializer(data=data)
assert serializer.is_valid()
# Objects are returned by save().
instances = serializer.save()
assert len(instances) == 3
- assert [item.name for item in instances] == ['a', 'b', 'c']
+ assert [item.name for item in instances] == ["a", "b", "c"]
# Objects have been created in the database.
assert BulkCreateModel.objects.count() == 3
- assert list(BulkCreateModel.objects.values_list('name', flat=True)) == ['a', 'b', 'c']
+ assert list(BulkCreateModel.objects.values_list("name", flat=True)) == [
+ "a",
+ "b",
+ "c",
+ ]
# Serializer returns correct data.
assert serializer.data == data
@@ -932,7 +1017,7 @@ class TestSerializerMetaClass(TestCase):
class ExampleSerializer(serializers.ModelSerializer):
class Meta:
model = MetaClassTestModel
- fields = 'text'
+ fields = "text"
msginitial = "The `fields` option must be a list or tuple"
with self.assertRaisesMessage(TypeError, msginitial):
@@ -942,7 +1027,7 @@ class TestSerializerMetaClass(TestCase):
class ExampleSerializer(serializers.ModelSerializer):
class Meta:
model = MetaClassTestModel
- exclude = 'text'
+ exclude = "text"
msginitial = "The `exclude` option must be a list or tuple"
with self.assertRaisesMessage(TypeError, msginitial):
@@ -952,8 +1037,8 @@ class TestSerializerMetaClass(TestCase):
class ExampleSerializer(serializers.ModelSerializer):
class Meta:
model = MetaClassTestModel
- fields = ('text',)
- exclude = ('text',)
+ fields = ("text",)
+ exclude = ("text",)
msginitial = "Cannot set both 'fields' and 'exclude' options on serializer ExampleSerializer."
with self.assertRaisesMessage(AssertionError, msginitial):
@@ -965,7 +1050,7 @@ class TestSerializerMetaClass(TestCase):
class Meta:
model = MetaClassTestModel
- exclude = ('text',)
+ exclude = ("text",)
expected = (
"Cannot both declare the field 'text' and include it in the "
@@ -983,20 +1068,17 @@ class Issue2704TestCase(TestCase):
class Meta:
model = OneFieldModel
- fields = ('char_field', 'additional_attr')
+ fields = ("char_field", "additional_attr")
- OneFieldModel.objects.create(char_field='abc')
+ OneFieldModel.objects.create(char_field="abc")
qs = OneFieldModel.objects.all()
for o in qs:
- o.additional_attr = '123'
+ o.additional_attr = "123"
serializer = TestSerializer(instance=qs, many=True)
- expected = [{
- 'char_field': 'abc',
- 'additional_attr': '123',
- }]
+ expected = [{"char_field": "abc", "additional_attr": "123"}]
assert serializer.data == expected
@@ -1005,7 +1087,7 @@ class DecimalFieldModel(models.Model):
decimal_field = models.DecimalField(
max_digits=3,
decimal_places=1,
- validators=[MinValueValidator(1), MaxValueValidator(3)]
+ validators=[MinValueValidator(1), MaxValueValidator(3)],
)
@@ -1014,42 +1096,45 @@ class TestDecimalFieldMappings(TestCase):
"""
Test that a `DecimalField` has no `DecimalValidator`.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = DecimalFieldModel
- fields = '__all__'
+ fields = "__all__"
serializer = TestSerializer()
- assert len(serializer.fields['decimal_field'].validators) == 2
+ assert len(serializer.fields["decimal_field"].validators) == 2
def test_min_value_is_passed(self):
"""
Test that the `MinValueValidator` is converted to the `min_value`
argument for the field.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = DecimalFieldModel
- fields = '__all__'
+ fields = "__all__"
serializer = TestSerializer()
- assert serializer.fields['decimal_field'].min_value == 1
+ assert serializer.fields["decimal_field"].min_value == 1
def test_max_value_is_passed(self):
"""
Test that the `MaxValueValidator` is converted to the `max_value`
argument for the field.
"""
+
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = DecimalFieldModel
- fields = '__all__'
+ fields = "__all__"
serializer = TestSerializer()
- assert serializer.fields['decimal_field'].max_value == 3
+ assert serializer.fields["decimal_field"].max_value == 3
class TestMetaInheritance(TestCase):
@@ -1059,7 +1144,7 @@ class TestMetaInheritance(TestCase):
class Meta:
model = OneFieldModel
- read_only_fields = ('char_field', 'non_model_field')
+ read_only_fields = ("char_field", "non_model_field")
fields = read_only_fields
extra_kwargs = {}
@@ -1067,17 +1152,21 @@ class TestMetaInheritance(TestCase):
class Meta(TestSerializer.Meta):
read_only_fields = ()
- test_expected = dedent("""
+ test_expected = dedent(
+ """
TestSerializer():
char_field = CharField(read_only=True)
non_model_field = CharField()
- """)
+ """
+ )
- child_expected = dedent("""
+ child_expected = dedent(
+ """
ChildSerializer():
char_field = CharField(max_length=100)
non_model_field = CharField()
- """)
+ """
+ )
self.assertEqual(unicode_repr(ChildSerializer()), child_expected)
self.assertEqual(unicode_repr(TestSerializer()), test_expected)
self.assertEqual(unicode_repr(ChildSerializer()), child_expected)
@@ -1088,7 +1177,9 @@ class OneToOneTargetTestModel(models.Model):
class OneToOneSourceTestModel(models.Model):
- target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE)
+ target = models.OneToOneField(
+ OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE
+ )
class TestModelFieldValues(TestCase):
@@ -1096,12 +1187,12 @@ class TestModelFieldValues(TestCase):
class ExampleSerializer(serializers.ModelSerializer):
class Meta:
model = OneToOneSourceTestModel
- fields = ('target',)
+ fields = ("target",)
- target = OneToOneTargetTestModel(id=1, text='abc')
+ target = OneToOneTargetTestModel(id=1, text="abc")
source = OneToOneSourceTestModel(target=target)
serializer = ExampleSerializer(source)
- self.assertEqual(serializer.data, {'target': 1})
+ self.assertEqual(serializer.data, {"target": 1})
class TestUniquenessOverride(TestCase):
@@ -1111,17 +1202,17 @@ class TestUniquenessOverride(TestCase):
field_2 = models.IntegerField()
class Meta:
- unique_together = (('field_1', 'field_2'),)
+ unique_together = (("field_1", "field_2"),)
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = TestModel
- fields = '__all__'
- extra_kwargs = {'field_1': {'required': False}}
+ fields = "__all__"
+ extra_kwargs = {"field_1": {"required": False}}
fields = TestSerializer().fields
- self.assertFalse(fields['field_1'].required)
- self.assertTrue(fields['field_2'].required)
+ self.assertFalse(fields["field_1"].required)
+ self.assertTrue(fields["field_2"].required)
class Issue3674Test(TestCase):
@@ -1130,56 +1221,61 @@ class Issue3674Test(TestCase):
title = models.CharField(max_length=64)
class TestChildModel(models.Model):
- parent = models.ForeignKey(TestParentModel, related_name='children', on_delete=models.CASCADE)
+ parent = models.ForeignKey(
+ TestParentModel, related_name="children", on_delete=models.CASCADE
+ )
value = models.CharField(primary_key=True, max_length=64)
class TestChildModelSerializer(serializers.ModelSerializer):
class Meta:
model = TestChildModel
- fields = ('value', 'parent')
+ fields = ("value", "parent")
class TestParentModelSerializer(serializers.ModelSerializer):
class Meta:
model = TestParentModel
- fields = ('id', 'title', 'children')
+ fields = ("id", "title", "children")
- parent_expected = dedent("""
+ parent_expected = dedent(
+ """
TestParentModelSerializer():
id = IntegerField(label='ID', read_only=True)
title = CharField(max_length=64)
children = PrimaryKeyRelatedField(many=True, queryset=TestChildModel.objects.all())
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestParentModelSerializer()), parent_expected)
- child_expected = dedent("""
+ child_expected = dedent(
+ """
TestChildModelSerializer():
value = CharField(max_length=64, validators=[])
parent = PrimaryKeyRelatedField(queryset=TestParentModel.objects.all())
- """)
+ """
+ )
self.assertEqual(unicode_repr(TestChildModelSerializer()), child_expected)
def test_nonID_PK_foreignkey_model_serializer(self):
-
class TestChildModelSerializer(serializers.ModelSerializer):
class Meta:
model = Issue3674ChildModel
- fields = ('value', 'parent')
+ fields = ("value", "parent")
class TestParentModelSerializer(serializers.ModelSerializer):
class Meta:
model = Issue3674ParentModel
- fields = ('id', 'title', 'children')
+ fields = ("id", "title", "children")
- parent = Issue3674ParentModel.objects.create(title='abc')
- child = Issue3674ChildModel.objects.create(value='def', parent=parent)
+ parent = Issue3674ParentModel.objects.create(title="abc")
+ child = Issue3674ChildModel.objects.create(value="def", parent=parent)
parent_serializer = TestParentModelSerializer(parent)
child_serializer = TestChildModelSerializer(child)
- parent_expected = {'children': ['def'], 'id': 1, 'title': 'abc'}
+ parent_expected = {"children": ["def"], "id": 1, "title": "abc"}
self.assertEqual(parent_serializer.data, parent_expected)
- child_expected = {'parent': 1, 'value': 'def'}
+ child_expected = {"parent": 1, "value": "def"}
self.assertEqual(child_serializer.data, child_expected)
@@ -1188,14 +1284,14 @@ class Issue4897TestCase(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = OneFieldModel
- fields = ('char_field',)
+ fields = ("char_field",)
readonly_fields = fields
- obj = OneFieldModel.objects.create(char_field='abc')
+ obj = OneFieldModel.objects.create(char_field="abc")
with pytest.raises(AssertionError) as cm:
TestSerializer(obj).fields
- cm.match(r'readonly_fields')
+ cm.match(r"readonly_fields")
class Test5004UniqueChoiceField(TestCase):
@@ -1203,12 +1299,14 @@ class Test5004UniqueChoiceField(TestCase):
class TestUniqueChoiceSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueChoiceModel
- fields = '__all__'
+ fields = "__all__"
- UniqueChoiceModel.objects.create(name='choice1')
- serializer = TestUniqueChoiceSerializer(data={'name': 'choice1'})
+ UniqueChoiceModel.objects.create(name="choice1")
+ serializer = TestUniqueChoiceSerializer(data={"name": "choice1"})
assert not serializer.is_valid()
- assert serializer.errors == {'name': ['unique choice model with this name already exists.']}
+ assert serializer.errors == {
+ "name": ["unique choice model with this name already exists."]
+ }
class TestFieldSource(TestCase):
@@ -1219,34 +1317,32 @@ class TestFieldSource(TestCase):
Similar to model example from test_serializer.py `test_default_for_multiple_dotted_source` method,
but using RelatedField, rather than CharField.
"""
+
class TestSerializer(serializers.ModelSerializer):
target = serializers.PrimaryKeyRelatedField(
- source='target.target', read_only=True, allow_null=True, default=None
+ source="target.target", read_only=True, allow_null=True, default=None
)
class Meta:
model = NestedForeignKeySource
- fields = ('target', )
+ fields = ("target",)
model = NestedForeignKeySource.objects.create()
- assert TestSerializer(model).data['target'] is None
+ assert TestSerializer(model).data["target"] is None
def test_named_field_source(self):
class TestSerializer(serializers.ModelSerializer):
-
class Meta:
model = RegularFieldsModel
- fields = ('number_field',)
- extra_kwargs = {
- 'number_field': {
- 'source': 'integer_field'
- }
- }
+ fields = ("number_field",)
+ extra_kwargs = {"number_field": {"source": "integer_field"}}
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
number_field = IntegerField(source='integer_field')
- """)
+ """
+ )
self.maxDiff = None
self.assertEqual(unicode_repr(TestSerializer()), expected)
@@ -1261,16 +1357,17 @@ class Issue6110TestModel(models.Model):
class Issue6110ModelSerializer(serializers.ModelSerializer):
class Meta:
model = Issue6110TestModel
- fields = ('name',)
+ fields = ("name",)
class Issue6110Test(TestCase):
-
def test_model_serializer_custom_manager(self):
- instance = Issue6110ModelSerializer().create({'name': 'test_name'})
- self.assertEqual(instance.name, 'test_name')
+ instance = Issue6110ModelSerializer().create({"name": "test_name"})
+ self.assertEqual(instance.name, "test_name")
def test_model_serializer_custom_manager_error_message(self):
- msginitial = ('Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`.')
+ msginitial = (
+ "Got a `TypeError` when calling `Issue6110TestModel.all_objects.create()`."
+ )
with self.assertRaisesMessage(TypeError, msginitial):
- Issue6110ModelSerializer().create({'wrong_param': 'wrong_param'})
+ Issue6110ModelSerializer().create({"wrong_param": "wrong_param"})
diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py
index 2ddd37ebb..eafb1c5ea 100644
--- a/tests/test_multitable_inheritance.py
+++ b/tests/test_multitable_inheritance.py
@@ -25,45 +25,41 @@ class AssociatedModel(RESTFrameworkModel):
class DerivedModelSerializer(serializers.ModelSerializer):
class Meta:
model = ChildModel
- fields = '__all__'
+ fields = "__all__"
class AssociatedModelSerializer(serializers.ModelSerializer):
class Meta:
model = AssociatedModel
- fields = '__all__'
+ fields = "__all__"
# Tests
class InheritedModelSerializationTests(TestCase):
-
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields
serialized fields
"""
- child = ChildModel(name1='parent name', name2='child name')
+ child = ChildModel(name1="parent name", name2="child name")
serializer = DerivedModelSerializer(child)
- assert set(serializer.data) == {'name1', 'name2', 'id'}
+ assert set(serializer.data) == {"name1", "name2", "id"}
def test_onetoone_primary_key_model_fields_as_expected(self):
"""
Assert that a model with a onetoone field that is the primary key is
not treated like a derived model
"""
- parent = ParentModel.objects.create(name1='parent name')
- associate = AssociatedModel.objects.create(name='hello', ref=parent)
+ parent = ParentModel.objects.create(name1="parent name")
+ associate = AssociatedModel.objects.create(name="hello", ref=parent)
serializer = AssociatedModelSerializer(associate)
- assert set(serializer.data) == {'name', 'ref'}
+ assert set(serializer.data) == {"name", "ref"}
def test_data_is_valid_without_parent_ptr(self):
"""
Assert that the pointer to the parent table is not a required field
for input data
"""
- data = {
- 'name1': 'parent name',
- 'name2': 'child name',
- }
+ data = {"name1": "parent name", "name2": "child name"}
serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True
diff --git a/tests/test_negotiation.py b/tests/test_negotiation.py
index 7ce3f92a9..bd5a984dc 100644
--- a/tests/test_negotiation.py
+++ b/tests/test_negotiation.py
@@ -4,32 +4,31 @@ import pytest
from django.http import Http404
from django.test import TestCase
-from rest_framework.negotiation import (
- BaseContentNegotiation, DefaultContentNegotiation
-)
+from rest_framework.negotiation import BaseContentNegotiation, DefaultContentNegotiation
from rest_framework.renderers import BaseRenderer
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
from rest_framework.utils.mediatypes import _MediaType
+
factory = APIRequestFactory()
class MockOpenAPIRenderer(BaseRenderer):
- media_type = 'application/openapi+json;version=2.0'
- format = 'swagger'
+ media_type = "application/openapi+json;version=2.0"
+ format = "swagger"
class MockJSONRenderer(BaseRenderer):
- media_type = 'application/json'
+ media_type = "application/json"
class MockHTMLRenderer(BaseRenderer):
- media_type = 'text/html'
+ media_type = "text/html"
class NoCharsetSpecifiedRenderer(BaseRenderer):
- media_type = 'my/media'
+ media_type = "my/media"
class TestAcceptedMediaType(TestCase):
@@ -41,54 +40,56 @@ class TestAcceptedMediaType(TestCase):
return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
- request = Request(factory.get('/'))
+ request = Request(factory.get("/"))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/json'
+ assert accepted_media_type == "application/json"
def test_client_underspecifies_accept_use_renderer(self):
- request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
+ request = Request(factory.get("/", HTTP_ACCEPT="*/*"))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/json'
+ assert accepted_media_type == "application/json"
def test_client_overspecifies_accept_use_client(self):
- request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
+ request = Request(factory.get("/", HTTP_ACCEPT="application/json; indent=8"))
accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/json; indent=8'
+ assert accepted_media_type == "application/json; indent=8"
def test_client_specifies_parameter(self):
- request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0'))
+ request = Request(
+ factory.get("/", HTTP_ACCEPT="application/openapi+json;version=2.0")
+ )
accepted_renderer, accepted_media_type = self.select_renderer(request)
- assert accepted_media_type == 'application/openapi+json;version=2.0'
- assert accepted_renderer.format == 'swagger'
+ assert accepted_media_type == "application/openapi+json;version=2.0"
+ assert accepted_renderer.format == "swagger"
def test_match_is_false_if_main_types_not_match(self):
- mediatype = _MediaType('test_1')
- anoter_mediatype = _MediaType('test_2')
+ mediatype = _MediaType("test_1")
+ anoter_mediatype = _MediaType("test_2")
assert mediatype.match(anoter_mediatype) is False
def test_mediatype_match_is_false_if_keys_not_match(self):
- mediatype = _MediaType(';test_param=foo')
- another_mediatype = _MediaType(';test_param=bar')
+ mediatype = _MediaType(";test_param=foo")
+ another_mediatype = _MediaType(";test_param=bar")
assert mediatype.match(another_mediatype) is False
def test_mediatype_precedence_with_wildcard_subtype(self):
- mediatype = _MediaType('test/*')
+ mediatype = _MediaType("test/*")
assert mediatype.precedence == 1
def test_mediatype_string_representation(self):
- mediatype = _MediaType('test/*; foo=bar')
- assert str(mediatype) == 'test/*; foo=bar'
+ mediatype = _MediaType("test/*; foo=bar")
+ assert str(mediatype) == "test/*; foo=bar"
def test_raise_error_if_no_suitable_renderers_found(self):
class MockRenderer(object):
- format = 'xml'
+ format = "xml"
+
renderers = [MockRenderer()]
with pytest.raises(Http404):
- self.negotiator.filter_renderers(renderers, format='json')
+ self.negotiator.filter_renderers(renderers, format="json")
class BaseContentNegotiationTests(TestCase):
-
def setUp(self):
self.negotiator = BaseContentNegotiation()
diff --git a/tests/test_one_to_one_with_inheritance.py b/tests/test_one_to_one_with_inheritance.py
index 789c7fcb9..0b881aba7 100644
--- a/tests/test_one_to_one_with_inheritance.py
+++ b/tests/test_one_to_one_with_inheritance.py
@@ -3,9 +3,9 @@ from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
+# Models
from rest_framework import serializers
from tests.models import RESTFrameworkModel
-# Models
from tests.test_multitable_inheritance import ChildModel
@@ -19,25 +19,24 @@ class ChildAssociatedModel(RESTFrameworkModel):
class DerivedModelSerializer(serializers.ModelSerializer):
class Meta:
model = ChildModel
- fields = ['id', 'name1', 'name2', 'childassociatedmodel']
+ fields = ["id", "name1", "name2", "childassociatedmodel"]
class ChildAssociatedModelSerializer(serializers.ModelSerializer):
-
class Meta:
model = ChildAssociatedModel
- fields = ['id', 'child_name']
+ fields = ["id", "child_name"]
# Tests
class InheritedModelSerializationTests(TestCase):
-
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields
serialized fields
"""
- child = ChildModel(name1='parent name', name2='child name')
+ child = ChildModel(name1="parent name", name2="child name")
serializer = DerivedModelSerializer(child)
- self.assertEqual(set(serializer.data),
- {'name1', 'name2', 'id', 'childassociatedmodel'})
+ self.assertEqual(
+ set(serializer.data), {"name1", "name2", "id", "childassociatedmodel"}
+ )
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index 6d940fe2b..edd93f61e 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -8,12 +8,18 @@ from django.test import TestCase
from django.utils import six
from rest_framework import (
- exceptions, filters, generics, pagination, serializers, status
+ exceptions,
+ filters,
+ generics,
+ pagination,
+ serializers,
+ status,
)
from rest_framework.pagination import PAGE_BREAK, PageLink
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
+
factory = APIRequestFactory()
@@ -33,39 +39,39 @@ class TestPaginationIntegration:
class BasicPagination(pagination.PageNumberPagination):
page_size = 5
- page_size_query_param = 'page_size'
+ page_size_query_param = "page_size"
max_page_size = 20
self.view = generics.ListAPIView.as_view(
serializer_class=PassThroughSerializer,
queryset=range(1, 101),
filter_backends=[EvenItemsOnly],
- pagination_class=BasicPagination
+ pagination_class=BasicPagination,
)
def test_filtered_items_are_paginated(self):
- request = factory.get('/', {'page': 2})
+ request = factory.get("/", {"page": 2})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
- 'results': [12, 14, 16, 18, 20],
- 'previous': 'http://testserver/',
- 'next': 'http://testserver/?page=3',
- 'count': 50
+ "results": [12, 14, 16, 18, 20],
+ "previous": "http://testserver/",
+ "next": "http://testserver/?page=3",
+ "count": 50,
}
def test_setting_page_size(self):
"""
When 'paginate_by_param' is set, the client may choose a page size.
"""
- request = factory.get('/', {'page_size': 10})
+ request = factory.get("/", {"page_size": 10})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
- 'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
- 'previous': None,
- 'next': 'http://testserver/?page=2&page_size=10',
- 'count': 50
+ "results": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
+ "previous": None,
+ "next": "http://testserver/?page=2&page_size=10",
+ "count": 50,
}
def test_setting_page_size_over_maximum(self):
@@ -73,70 +79,84 @@ class TestPaginationIntegration:
When page_size parameter exceeds maximum allowable,
then it should be capped to the maximum.
"""
- request = factory.get('/', {'page_size': 1000})
+ request = factory.get("/", {"page_size": 1000})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
- 'results': [
- 2, 4, 6, 8, 10, 12, 14, 16, 18, 20,
- 22, 24, 26, 28, 30, 32, 34, 36, 38, 40
+ "results": [
+ 2,
+ 4,
+ 6,
+ 8,
+ 10,
+ 12,
+ 14,
+ 16,
+ 18,
+ 20,
+ 22,
+ 24,
+ 26,
+ 28,
+ 30,
+ 32,
+ 34,
+ 36,
+ 38,
+ 40,
],
- 'previous': None,
- 'next': 'http://testserver/?page=2&page_size=1000',
- 'count': 50
+ "previous": None,
+ "next": "http://testserver/?page=2&page_size=1000",
+ "count": 50,
}
def test_setting_page_size_to_zero(self):
"""
When page_size parameter is invalid it should return to the default.
"""
- request = factory.get('/', {'page_size': 0})
+ request = factory.get("/", {"page_size": 0})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
- 'results': [2, 4, 6, 8, 10],
- 'previous': None,
- 'next': 'http://testserver/?page=2&page_size=0',
- 'count': 50
+ "results": [2, 4, 6, 8, 10],
+ "previous": None,
+ "next": "http://testserver/?page=2&page_size=0",
+ "count": 50,
}
def test_additional_query_params_are_preserved(self):
- request = factory.get('/', {'page': 2, 'filter': 'even'})
+ request = factory.get("/", {"page": 2, "filter": "even"})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
- 'results': [12, 14, 16, 18, 20],
- 'previous': 'http://testserver/?filter=even',
- 'next': 'http://testserver/?filter=even&page=3',
- 'count': 50
+ "results": [12, 14, 16, 18, 20],
+ "previous": "http://testserver/?filter=even",
+ "next": "http://testserver/?filter=even&page=3",
+ "count": 50,
}
def test_empty_query_params_are_preserved(self):
- request = factory.get('/', {'page': 2, 'filter': ''})
+ request = factory.get("/", {"page": 2, "filter": ""})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
- 'results': [12, 14, 16, 18, 20],
- 'previous': 'http://testserver/?filter=',
- 'next': 'http://testserver/?filter=&page=3',
- 'count': 50
+ "results": [12, 14, 16, 18, 20],
+ "previous": "http://testserver/?filter=",
+ "next": "http://testserver/?filter=&page=3",
+ "count": 50,
}
def test_404_not_found_for_zero_page(self):
- request = factory.get('/', {'page': '0'})
+ request = factory.get("/", {"page": "0"})
response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
- assert response.data == {
- 'detail': 'Invalid page.'
- }
+ assert response.data == {"detail": "Invalid page."}
def test_404_not_found_for_invalid_page(self):
- request = factory.get('/', {'page': 'invalid'})
+ request = factory.get("/", {"page": "invalid"})
response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
- assert response.data == {
- 'detail': 'Invalid page.'
- }
+ assert response.data == {"detail": "Invalid page."}
class TestPaginationDisabledIntegration:
@@ -152,11 +172,11 @@ class TestPaginationDisabledIntegration:
self.view = generics.ListAPIView.as_view(
serializer_class=PassThroughSerializer,
queryset=range(1, 101),
- pagination_class=None
+ pagination_class=None,
)
def test_unpaginated_list(self):
- request = factory.get('/', {'page': 2})
+ request = factory.get("/", {"page": 2})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == list(range(1, 101))
@@ -185,81 +205,81 @@ class TestPageNumberPagination:
return self.pagination.get_html_context()
def test_no_page_number(self):
- request = Request(factory.get('/'))
+ request = Request(factory.get("/"))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [1, 2, 3, 4, 5]
assert content == {
- 'results': [1, 2, 3, 4, 5],
- 'previous': None,
- 'next': 'http://testserver/?page=2',
- 'count': 100
+ "results": [1, 2, 3, 4, 5],
+ "previous": None,
+ "next": "http://testserver/?page=2",
+ "count": 100,
}
assert context == {
- 'previous_url': None,
- 'next_url': 'http://testserver/?page=2',
- 'page_links': [
- PageLink('http://testserver/', 1, True, False),
- PageLink('http://testserver/?page=2', 2, False, False),
- PageLink('http://testserver/?page=3', 3, False, False),
+ "previous_url": None,
+ "next_url": "http://testserver/?page=2",
+ "page_links": [
+ PageLink("http://testserver/", 1, True, False),
+ PageLink("http://testserver/?page=2", 2, False, False),
+ PageLink("http://testserver/?page=3", 3, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?page=20', 20, False, False),
- ]
+ PageLink("http://testserver/?page=20", 20, False, False),
+ ],
}
assert self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type)
def test_second_page(self):
- request = Request(factory.get('/', {'page': 2}))
+ request = Request(factory.get("/", {"page": 2}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [6, 7, 8, 9, 10]
assert content == {
- 'results': [6, 7, 8, 9, 10],
- 'previous': 'http://testserver/',
- 'next': 'http://testserver/?page=3',
- 'count': 100
+ "results": [6, 7, 8, 9, 10],
+ "previous": "http://testserver/",
+ "next": "http://testserver/?page=3",
+ "count": 100,
}
assert context == {
- 'previous_url': 'http://testserver/',
- 'next_url': 'http://testserver/?page=3',
- 'page_links': [
- PageLink('http://testserver/', 1, False, False),
- PageLink('http://testserver/?page=2', 2, True, False),
- PageLink('http://testserver/?page=3', 3, False, False),
+ "previous_url": "http://testserver/",
+ "next_url": "http://testserver/?page=3",
+ "page_links": [
+ PageLink("http://testserver/", 1, False, False),
+ PageLink("http://testserver/?page=2", 2, True, False),
+ PageLink("http://testserver/?page=3", 3, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?page=20', 20, False, False),
- ]
+ PageLink("http://testserver/?page=20", 20, False, False),
+ ],
}
def test_last_page(self):
- request = Request(factory.get('/', {'page': 'last'}))
+ request = Request(factory.get("/", {"page": "last"}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [96, 97, 98, 99, 100]
assert content == {
- 'results': [96, 97, 98, 99, 100],
- 'previous': 'http://testserver/?page=19',
- 'next': None,
- 'count': 100
+ "results": [96, 97, 98, 99, 100],
+ "previous": "http://testserver/?page=19",
+ "next": None,
+ "count": 100,
}
assert context == {
- 'previous_url': 'http://testserver/?page=19',
- 'next_url': None,
- 'page_links': [
- PageLink('http://testserver/', 1, False, False),
+ "previous_url": "http://testserver/?page=19",
+ "next_url": None,
+ "page_links": [
+ PageLink("http://testserver/", 1, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?page=18', 18, False, False),
- PageLink('http://testserver/?page=19', 19, False, False),
- PageLink('http://testserver/?page=20', 20, True, False),
- ]
+ PageLink("http://testserver/?page=18", 18, False, False),
+ PageLink("http://testserver/?page=19", 19, False, False),
+ PageLink("http://testserver/?page=20", 20, True, False),
+ ],
}
def test_invalid_page(self):
- request = Request(factory.get('/', {'page': 'invalid'}))
+ request = Request(factory.get("/", {"page": "invalid"}))
with pytest.raises(exceptions.NotFound):
self.paginate_queryset(request)
@@ -295,29 +315,22 @@ class TestPageNumberPaginationOverride:
return self.pagination.get_html_context()
def test_no_page_number(self):
- request = Request(factory.get('/'))
+ request = Request(factory.get("/"))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [1]
- assert content == {
- 'results': [1, ],
- 'previous': None,
- 'next': None,
- 'count': 1
- }
+ assert content == {"results": [1], "previous": None, "next": None, "count": 1}
assert context == {
- 'previous_url': None,
- 'next_url': None,
- 'page_links': [
- PageLink('http://testserver/', 1, True, False),
- ]
+ "previous_url": None,
+ "next_url": None,
+ "page_links": [PageLink("http://testserver/", 1, True, False)],
}
assert not self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type)
def test_invalid_page(self):
- request = Request(factory.get('/', {'page': 'invalid'}))
+ request = Request(factory.get("/", {"page": "invalid"}))
with pytest.raises(exceptions.NotFound):
self.paginate_queryset(request)
@@ -346,27 +359,27 @@ class TestLimitOffset:
return self.pagination.get_html_context()
def test_no_offset(self):
- request = Request(factory.get('/', {'limit': 5}))
+ request = Request(factory.get("/", {"limit": 5}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [1, 2, 3, 4, 5]
assert content == {
- 'results': [1, 2, 3, 4, 5],
- 'previous': None,
- 'next': 'http://testserver/?limit=5&offset=5',
- 'count': 100
+ "results": [1, 2, 3, 4, 5],
+ "previous": None,
+ "next": "http://testserver/?limit=5&offset=5",
+ "count": 100,
}
assert context == {
- 'previous_url': None,
- 'next_url': 'http://testserver/?limit=5&offset=5',
- 'page_links': [
- PageLink('http://testserver/?limit=5', 1, True, False),
- PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
- PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ "previous_url": None,
+ "next_url": "http://testserver/?limit=5&offset=5",
+ "page_links": [
+ PageLink("http://testserver/?limit=5", 1, True, False),
+ PageLink("http://testserver/?limit=5&offset=5", 2, False, False),
+ PageLink("http://testserver/?limit=5&offset=10", 3, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
- ]
+ PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
+ ],
}
assert self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type)
@@ -374,7 +387,8 @@ class TestLimitOffset:
def test_pagination_not_applied_if_limit_or_default_limit_not_set(self):
class MockPagination(pagination.LimitOffsetPagination):
default_limit = None
- request = Request(factory.get('/'))
+
+ request = Request(factory.get("/"))
queryset = MockPagination().paginate_queryset(self.queryset, request)
assert queryset is None
@@ -384,104 +398,104 @@ class TestLimitOffset:
* The first page should still be offset zero.
* We may end up displaying an extra page in the pagination control.
"""
- request = Request(factory.get('/', {'limit': 5, 'offset': 1}))
+ request = Request(factory.get("/", {"limit": 5, "offset": 1}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [2, 3, 4, 5, 6]
assert content == {
- 'results': [2, 3, 4, 5, 6],
- 'previous': 'http://testserver/?limit=5',
- 'next': 'http://testserver/?limit=5&offset=6',
- 'count': 100
+ "results": [2, 3, 4, 5, 6],
+ "previous": "http://testserver/?limit=5",
+ "next": "http://testserver/?limit=5&offset=6",
+ "count": 100,
}
assert context == {
- 'previous_url': 'http://testserver/?limit=5',
- 'next_url': 'http://testserver/?limit=5&offset=6',
- 'page_links': [
- PageLink('http://testserver/?limit=5', 1, False, False),
- PageLink('http://testserver/?limit=5&offset=1', 2, True, False),
- PageLink('http://testserver/?limit=5&offset=6', 3, False, False),
+ "previous_url": "http://testserver/?limit=5",
+ "next_url": "http://testserver/?limit=5&offset=6",
+ "page_links": [
+ PageLink("http://testserver/?limit=5", 1, False, False),
+ PageLink("http://testserver/?limit=5&offset=1", 2, True, False),
+ PageLink("http://testserver/?limit=5&offset=6", 3, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?limit=5&offset=96', 21, False, False),
- ]
+ PageLink("http://testserver/?limit=5&offset=96", 21, False, False),
+ ],
}
def test_first_offset(self):
- request = Request(factory.get('/', {'limit': 5, 'offset': 5}))
+ request = Request(factory.get("/", {"limit": 5, "offset": 5}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [6, 7, 8, 9, 10]
assert content == {
- 'results': [6, 7, 8, 9, 10],
- 'previous': 'http://testserver/?limit=5',
- 'next': 'http://testserver/?limit=5&offset=10',
- 'count': 100
+ "results": [6, 7, 8, 9, 10],
+ "previous": "http://testserver/?limit=5",
+ "next": "http://testserver/?limit=5&offset=10",
+ "count": 100,
}
assert context == {
- 'previous_url': 'http://testserver/?limit=5',
- 'next_url': 'http://testserver/?limit=5&offset=10',
- 'page_links': [
- PageLink('http://testserver/?limit=5', 1, False, False),
- PageLink('http://testserver/?limit=5&offset=5', 2, True, False),
- PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
+ "previous_url": "http://testserver/?limit=5",
+ "next_url": "http://testserver/?limit=5&offset=10",
+ "page_links": [
+ PageLink("http://testserver/?limit=5", 1, False, False),
+ PageLink("http://testserver/?limit=5&offset=5", 2, True, False),
+ PageLink("http://testserver/?limit=5&offset=10", 3, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
- ]
+ PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
+ ],
}
def test_middle_offset(self):
- request = Request(factory.get('/', {'limit': 5, 'offset': 10}))
+ request = Request(factory.get("/", {"limit": 5, "offset": 10}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [11, 12, 13, 14, 15]
assert content == {
- 'results': [11, 12, 13, 14, 15],
- 'previous': 'http://testserver/?limit=5&offset=5',
- 'next': 'http://testserver/?limit=5&offset=15',
- 'count': 100
+ "results": [11, 12, 13, 14, 15],
+ "previous": "http://testserver/?limit=5&offset=5",
+ "next": "http://testserver/?limit=5&offset=15",
+ "count": 100,
}
assert context == {
- 'previous_url': 'http://testserver/?limit=5&offset=5',
- 'next_url': 'http://testserver/?limit=5&offset=15',
- 'page_links': [
- PageLink('http://testserver/?limit=5', 1, False, False),
- PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
- PageLink('http://testserver/?limit=5&offset=10', 3, True, False),
- PageLink('http://testserver/?limit=5&offset=15', 4, False, False),
+ "previous_url": "http://testserver/?limit=5&offset=5",
+ "next_url": "http://testserver/?limit=5&offset=15",
+ "page_links": [
+ PageLink("http://testserver/?limit=5", 1, False, False),
+ PageLink("http://testserver/?limit=5&offset=5", 2, False, False),
+ PageLink("http://testserver/?limit=5&offset=10", 3, True, False),
+ PageLink("http://testserver/?limit=5&offset=15", 4, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
- ]
+ PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
+ ],
}
def test_ending_offset(self):
- request = Request(factory.get('/', {'limit': 5, 'offset': 95}))
+ request = Request(factory.get("/", {"limit": 5, "offset": 95}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [96, 97, 98, 99, 100]
assert content == {
- 'results': [96, 97, 98, 99, 100],
- 'previous': 'http://testserver/?limit=5&offset=90',
- 'next': None,
- 'count': 100
+ "results": [96, 97, 98, 99, 100],
+ "previous": "http://testserver/?limit=5&offset=90",
+ "next": None,
+ "count": 100,
}
assert context == {
- 'previous_url': 'http://testserver/?limit=5&offset=90',
- 'next_url': None,
- 'page_links': [
- PageLink('http://testserver/?limit=5', 1, False, False),
+ "previous_url": "http://testserver/?limit=5&offset=90",
+ "next_url": None,
+ "page_links": [
+ PageLink("http://testserver/?limit=5", 1, False, False),
PAGE_BREAK,
- PageLink('http://testserver/?limit=5&offset=85', 18, False, False),
- PageLink('http://testserver/?limit=5&offset=90', 19, False, False),
- PageLink('http://testserver/?limit=5&offset=95', 20, True, False),
- ]
+ PageLink("http://testserver/?limit=5&offset=85", 18, False, False),
+ PageLink("http://testserver/?limit=5&offset=90", 19, False, False),
+ PageLink("http://testserver/?limit=5&offset=95", 20, True, False),
+ ],
}
def test_erronous_offset(self):
- request = Request(factory.get('/', {'limit': 5, 'offset': 1000}))
+ request = Request(factory.get("/", {"limit": 5, "offset": 1000}))
queryset = self.paginate_queryset(request)
self.get_paginated_content(queryset)
self.get_html_context()
@@ -490,7 +504,7 @@ class TestLimitOffset:
"""
An invalid offset query param should be treated as 0.
"""
- request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'}))
+ request = Request(factory.get("/", {"limit": 5, "offset": "invalid"}))
queryset = self.paginate_queryset(request)
assert queryset == [1, 2, 3, 4, 5]
@@ -498,27 +512,31 @@ class TestLimitOffset:
"""
An invalid limit query param should be ignored in favor of the default.
"""
- request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0}))
+ request = Request(factory.get("/", {"limit": "invalid", "offset": 0}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
next_limit = self.pagination.default_limit
next_offset = self.pagination.default_limit
- next_url = 'http://testserver/?limit={0}&offset={1}'.format(next_limit, next_offset)
+ next_url = "http://testserver/?limit={0}&offset={1}".format(
+ next_limit, next_offset
+ )
assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
- assert content.get('next') == next_url
+ assert content.get("next") == next_url
def test_zero_limit(self):
"""
An zero limit query param should be ignored in favor of the default.
"""
- request = Request(factory.get('/', {'limit': 0, 'offset': 0}))
+ request = Request(factory.get("/", {"limit": 0, "offset": 0}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
next_limit = self.pagination.default_limit
next_offset = self.pagination.default_limit
- next_url = 'http://testserver/?limit={0}&offset={1}'.format(next_limit, next_offset)
+ next_url = "http://testserver/?limit={0}&offset={1}".format(
+ next_limit, next_offset
+ )
assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
- assert content.get('next') == next_url
+ assert content.get("next") == next_url
def test_max_limit(self):
"""
@@ -526,47 +544,46 @@ class TestLimitOffset:
requested limit is greater than the max_limit
"""
offset = 50
- request = Request(factory.get('/', {'limit': '11235', 'offset': offset}))
+ request = Request(factory.get("/", {"limit": "11235", "offset": offset}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
max_limit = self.pagination.max_limit
next_offset = offset + max_limit
prev_offset = offset - max_limit
- base_url = 'http://testserver/?limit={0}'.format(max_limit)
- next_url = base_url + '&offset={0}'.format(next_offset)
- prev_url = base_url + '&offset={0}'.format(prev_offset)
+ base_url = "http://testserver/?limit={0}".format(max_limit)
+ next_url = base_url + "&offset={0}".format(next_offset)
+ prev_url = base_url + "&offset={0}".format(prev_offset)
assert queryset == list(range(51, 66))
- assert content.get('next') == next_url
- assert content.get('previous') == prev_url
+ assert content.get("next") == next_url
+ assert content.get("previous") == prev_url
class CursorPaginationTestsMixin:
-
def test_invalid_cursor(self):
- request = Request(factory.get('/', {'cursor': '123'}))
+ request = Request(factory.get("/", {"cursor": "123"}))
with pytest.raises(exceptions.NotFound):
self.pagination.paginate_queryset(self.queryset, request)
def test_use_with_ordering_filter(self):
class MockView:
filter_backends = (filters.OrderingFilter,)
- ordering_fields = ['username', 'created']
- ordering = 'created'
+ ordering_fields = ["username", "created"]
+ ordering = "created"
- request = Request(factory.get('/', {'ordering': 'username'}))
+ request = Request(factory.get("/", {"ordering": "username"}))
ordering = self.pagination.get_ordering(request, [], MockView())
- assert ordering == ('username',)
+ assert ordering == ("username",)
- request = Request(factory.get('/', {'ordering': '-username'}))
+ request = Request(factory.get("/", {"ordering": "-username"}))
ordering = self.pagination.get_ordering(request, [], MockView())
- assert ordering == ('-username',)
+ assert ordering == ("-username",)
- request = Request(factory.get('/', {'ordering': 'invalid'}))
+ request = Request(factory.get("/", {"ordering": "invalid"}))
ordering = self.pagination.get_ordering(request, [], MockView())
- assert ordering == ('created',)
+ assert ordering == ("created",)
def test_cursor_pagination(self):
- (previous, current, next, previous_url, next_url) = self.get_pages('/')
+ (previous, current, next, previous_url, next_url) = self.get_pages("/")
assert previous is None
assert current == [1, 1, 1, 1, 1]
@@ -635,7 +652,9 @@ class CursorPaginationTestsMixin:
assert isinstance(self.pagination.to_html(), six.text_type)
def test_cursor_pagination_with_page_size(self):
- (previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=20')
+ (previous, current, next, previous_url, next_url) = self.get_pages(
+ "/?page_size=20"
+ )
assert previous is None
assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7]
@@ -647,7 +666,9 @@ class CursorPaginationTestsMixin:
assert next is None
def test_cursor_pagination_with_page_size_over_limit(self):
- (previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=30')
+ (previous, current, next, previous_url, next_url) = self.get_pages(
+ "/?page_size=30"
+ )
assert previous is None
assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7]
@@ -659,7 +680,9 @@ class CursorPaginationTestsMixin:
assert next is None
def test_cursor_pagination_with_page_size_zero(self):
- (previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=0')
+ (previous, current, next, previous_url, next_url) = self.get_pages(
+ "/?page_size=0"
+ )
assert previous is None
assert current == [1, 1, 1, 1, 1]
@@ -726,7 +749,9 @@ class CursorPaginationTestsMixin:
assert next == [1, 2, 3, 4, 4]
def test_cursor_pagination_with_page_size_negative(self):
- (previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=-5')
+ (previous, current, next, previous_url, next_url) = self.get_pages(
+ "/?page_size=-5"
+ )
assert previous is None
assert current == [1, 1, 1, 1, 1]
@@ -809,19 +834,17 @@ class TestCursorPagination(CursorPaginationTestsMixin):
def filter(self, created__gt=None, created__lt=None):
if created__gt is not None:
- return MockQuerySet([
- item for item in self.items
- if item.created > int(created__gt)
- ])
+ return MockQuerySet(
+ [item for item in self.items if item.created > int(created__gt)]
+ )
assert created__lt is not None
- return MockQuerySet([
- item for item in self.items
- if item.created < int(created__lt)
- ])
+ return MockQuerySet(
+ [item for item in self.items if item.created < int(created__lt)]
+ )
def order_by(self, *ordering):
- if ordering[0].startswith('-'):
+ if ordering[0].startswith("-"):
return MockQuerySet(list(reversed(self.items)))
return self
@@ -830,21 +853,48 @@ class TestCursorPagination(CursorPaginationTestsMixin):
class ExamplePagination(pagination.CursorPagination):
page_size = 5
- page_size_query_param = 'page_size'
+ page_size_query_param = "page_size"
max_page_size = 20
- ordering = 'created'
+ ordering = "created"
self.pagination = ExamplePagination()
- self.queryset = MockQuerySet([
- MockObject(idx) for idx in [
- 1, 1, 1, 1, 1,
- 1, 2, 3, 4, 4,
- 4, 4, 5, 6, 7,
- 7, 7, 7, 7, 7,
- 7, 7, 7, 8, 9,
- 9, 9, 9, 9, 9
+ self.queryset = MockQuerySet(
+ [
+ MockObject(idx)
+ for idx in [
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 3,
+ 4,
+ 4,
+ 4,
+ 4,
+ 5,
+ 6,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 8,
+ 9,
+ 9,
+ 9,
+ 9,
+ 9,
+ 9,
+ ]
]
- ])
+ )
def get_pages(self, url):
"""
@@ -888,18 +938,42 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
def setUp(self):
class ExamplePagination(pagination.CursorPagination):
page_size = 5
- page_size_query_param = 'page_size'
+ page_size_query_param = "page_size"
max_page_size = 20
- ordering = 'created'
+ ordering = "created"
self.pagination = ExamplePagination()
data = [
- 1, 1, 1, 1, 1,
- 1, 2, 3, 4, 4,
- 4, 4, 5, 6, 7,
- 7, 7, 7, 7, 7,
- 7, 7, 7, 8, 9,
- 9, 9, 9, 9, 9
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 1,
+ 2,
+ 3,
+ 4,
+ 4,
+ 4,
+ 4,
+ 5,
+ 6,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 7,
+ 8,
+ 9,
+ 9,
+ 9,
+ 9,
+ 9,
+ 9,
]
for idx in data:
CursorPaginationModel.objects.create(created=idx)
@@ -914,7 +988,7 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
"""
request = Request(factory.get(url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
- current = [item['created'] for item in queryset]
+ current = [item["created"] for item in queryset]
next_url = self.pagination.get_next_link()
previous_url = self.pagination.get_previous_link()
@@ -922,14 +996,14 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
if next_url is not None:
request = Request(factory.get(next_url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
- next = [item['created'] for item in queryset]
+ next = [item["created"] for item in queryset]
else:
next = None
if previous_url is not None:
request = Request(factory.get(previous_url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
- previous = [item['created'] for item in queryset]
+ previous = [item["created"] for item in queryset]
else:
previous = None
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index e793948e3..0212f0fa4 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -7,7 +7,8 @@ import math
import pytest
from django import forms
from django.core.files.uploadhandler import (
- MemoryFileUploadHandler, TemporaryFileUploadHandler
+ MemoryFileUploadHandler,
+ TemporaryFileUploadHandler,
)
from django.http.request import RawPostDataException
from django.test import TestCase
@@ -15,7 +16,10 @@ from django.utils.six import StringIO
from rest_framework.exceptions import ParseError
from rest_framework.parsers import (
- FileUploadParser, FormParser, JSONParser, MultiPartParser
+ FileUploadParser,
+ FormParser,
+ JSONParser,
+ MultiPartParser,
)
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
@@ -44,16 +48,15 @@ class TestFileUploadParser(TestCase):
def setUp(self):
class MockRequest(object):
pass
- self.stream = io.BytesIO(
- "Test text file".encode('utf-8')
- )
+
+ self.stream = io.BytesIO("Test text file".encode("utf-8"))
request = MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),)
request.META = {
- 'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
- 'HTTP_CONTENT_LENGTH': 14,
+ "HTTP_CONTENT_DISPOSITION": "Content-Disposition: inline; filename=file.txt",
+ "HTTP_CONTENT_LENGTH": 14,
}
- self.parser_context = {'request': request, 'kwargs': {}}
+ self.parser_context = {"request": request, "kwargs": {}}
def test_parse(self):
"""
@@ -62,7 +65,7 @@ class TestFileUploadParser(TestCase):
parser = FileUploadParser()
self.stream.seek(0)
data_and_files = parser.parse(self.stream, None, self.parser_context)
- file_obj = data_and_files.files['file']
+ file_obj = data_and_files.files["file"]
assert file_obj.size == 14
def test_parse_missing_filename(self):
@@ -71,10 +74,13 @@ class TestFileUploadParser(TestCase):
"""
parser = FileUploadParser()
self.stream.seek(0)
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
- assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
+ assert (
+ str(excinfo.value)
+ == "Missing filename. Request should include a Content-Disposition header with a filename parameter."
+ )
def test_parse_missing_filename_multiple_upload_handlers(self):
"""
@@ -83,14 +89,17 @@ class TestFileUploadParser(TestCase):
"""
parser = FileUploadParser()
self.stream.seek(0)
- self.parser_context['request'].upload_handlers = (
+ self.parser_context["request"].upload_handlers = (
+ MemoryFileUploadHandler(),
MemoryFileUploadHandler(),
- MemoryFileUploadHandler()
)
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
- assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
+ assert (
+ str(excinfo.value)
+ == "Missing filename. Request should include a Content-Disposition header with a filename parameter."
+ )
def test_parse_missing_filename_large_file(self):
"""
@@ -98,54 +107,59 @@ class TestFileUploadParser(TestCase):
"""
parser = FileUploadParser()
self.stream.seek(0)
- self.parser_context['request'].upload_handlers = (
- TemporaryFileUploadHandler(),
- )
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
+ self.parser_context["request"].upload_handlers = (TemporaryFileUploadHandler(),)
+ self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
- assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
+ assert (
+ str(excinfo.value)
+ == "Missing filename. Request should include a Content-Disposition header with a filename parameter."
+ )
def test_get_filename(self):
parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'file.txt'
+ assert filename == "file.txt"
def test_get_encoded_filename(self):
parser = FileUploadParser()
- self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
+ self.__replace_content_disposition("inline; filename*=utf-8''ÀĥƦ.txt")
filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'ÀĥƦ.txt'
+ assert filename == "ÀĥƦ.txt"
- self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
+ self.__replace_content_disposition(
+ "inline; filename=fallback.txt; filename*=utf-8''ÀĥƦ.txt"
+ )
filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'ÀĥƦ.txt'
+ assert filename == "ÀĥƦ.txt"
- self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
+ self.__replace_content_disposition(
+ "inline; filename=fallback.txt; filename*=utf-8'en-us'ÀĥƦ.txt"
+ )
filename = parser.get_filename(self.stream, None, self.parser_context)
- assert filename == 'ÀĥƦ.txt'
+ assert filename == "ÀĥƦ.txt"
def __replace_content_disposition(self, disposition):
- self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
+ self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = disposition
class TestJSONParser(TestCase):
def bytes(self, value):
- return io.BytesIO(value.encode('utf-8'))
+ return io.BytesIO(value.encode("utf-8"))
def test_float_strictness(self):
parser = JSONParser()
# Default to strict
- for value in ['Infinity', '-Infinity', 'NaN']:
+ for value in ["Infinity", "-Infinity", "NaN"]:
with pytest.raises(ParseError):
parser.parse(self.bytes(value))
parser.strict = False
- assert parser.parse(self.bytes('Infinity')) == float('inf')
- assert parser.parse(self.bytes('-Infinity')) == float('-inf')
- assert math.isnan(parser.parse(self.bytes('NaN')))
+ assert parser.parse(self.bytes("Infinity")) == float("inf")
+ assert parser.parse(self.bytes("-Infinity")) == float("-inf")
+ assert math.isnan(parser.parse(self.bytes("NaN")))
class TestPOSTAccessed(TestCase):
@@ -153,28 +167,28 @@ class TestPOSTAccessed(TestCase):
self.factory = APIRequestFactory()
def test_post_accessed_in_post_method(self):
- django_request = self.factory.post('/', {'foo': 'bar'})
+ django_request = self.factory.post("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
- assert request.POST == {'foo': ['bar']}
- assert request.data == {'foo': ['bar']}
+ assert request.POST == {"foo": ["bar"]}
+ assert request.data == {"foo": ["bar"]}
def test_post_accessed_in_post_method_with_json_parser(self):
- django_request = self.factory.post('/', {'foo': 'bar'})
+ django_request = self.factory.post("/", {"foo": "bar"})
request = Request(django_request, parsers=[JSONParser()])
django_request.POST
assert request.POST == {}
assert request.data == {}
def test_post_accessed_in_put_method(self):
- django_request = self.factory.put('/', {'foo': 'bar'})
+ django_request = self.factory.put("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
- assert request.POST == {'foo': ['bar']}
- assert request.data == {'foo': ['bar']}
+ assert request.POST == {"foo": ["bar"]}
+ assert request.data == {"foo": ["bar"]}
def test_request_read_before_parsing(self):
- django_request = self.factory.put('/', {'foo': 'bar'})
+ django_request = self.factory.put("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.read()
with pytest.raises(RawPostDataException):
diff --git a/tests/test_permissions.py b/tests/test_permissions.py
index 2fabdfa05..8cd17b129 100644
--- a/tests/test_permissions.py
+++ b/tests/test_permissions.py
@@ -12,8 +12,14 @@ from django.test import TestCase
from django.urls import ResolverMatch
from rest_framework import (
- HTTP_HEADER_ENCODING, RemovedInDRF310Warning, authentication, generics,
- permissions, serializers, status, views
+ HTTP_HEADER_ENCODING,
+ RemovedInDRF310Warning,
+ authentication,
+ generics,
+ permissions,
+ serializers,
+ status,
+ views,
)
from rest_framework.compat import PY36, is_guardian_installed, mock
from rest_framework.filters import DjangoObjectPermissionsFilter
@@ -21,13 +27,14 @@ from rest_framework.routers import DefaultRouter
from rest_framework.test import APIRequestFactory
from tests.models import BasicModel
+
factory = APIRequestFactory()
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
- fields = '__all__'
+ fields = "__all__"
class RootView(generics.ListCreateAPIView):
@@ -68,35 +75,47 @@ empty_list_view = EmptyListView.as_view()
def basic_auth_header(username, password):
- credentials = ('%s:%s' % (username, password))
- base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
- return 'Basic %s' % base64_credentials
+ credentials = "%s:%s" % (username, password)
+ base64_credentials = base64.b64encode(
+ credentials.encode(HTTP_HEADER_ENCODING)
+ ).decode(HTTP_HEADER_ENCODING)
+ return "Basic %s" % base64_credentials
class ModelPermissionsIntegrationTests(TestCase):
def setUp(self):
- User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
- user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
- user.user_permissions.set([
- Permission.objects.get(codename='add_basicmodel'),
- Permission.objects.get(codename='change_basicmodel'),
- Permission.objects.get(codename='delete_basicmodel')
- ])
+ User.objects.create_user("disallowed", "disallowed@example.com", "password")
+ user = User.objects.create_user(
+ "permitted", "permitted@example.com", "password"
+ )
+ user.user_permissions.set(
+ [
+ Permission.objects.get(codename="add_basicmodel"),
+ Permission.objects.get(codename="change_basicmodel"),
+ Permission.objects.get(codename="delete_basicmodel"),
+ ]
+ )
- user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
- user.user_permissions.set([
- Permission.objects.get(codename='change_basicmodel'),
- ])
+ user = User.objects.create_user(
+ "updateonly", "updateonly@example.com", "password"
+ )
+ user.user_permissions.set(
+ [Permission.objects.get(codename="change_basicmodel")]
+ )
- self.permitted_credentials = basic_auth_header('permitted', 'password')
- self.disallowed_credentials = basic_auth_header('disallowed', 'password')
- self.updateonly_credentials = basic_auth_header('updateonly', 'password')
+ self.permitted_credentials = basic_auth_header("permitted", "password")
+ self.disallowed_credentials = basic_auth_header("disallowed", "password")
+ self.updateonly_credentials = basic_auth_header("updateonly", "password")
- BasicModel(text='foo').save()
+ BasicModel(text="foo").save()
def test_has_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
+ request = factory.post(
+ "/",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.permitted_credentials,
+ )
response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
@@ -106,121 +125,125 @@ class ModelPermissionsIntegrationTests(TestCase):
apply to APIRoot view. More specifically we check expected behavior of
``_ignore_model_permissions`` attribute support.
"""
- request = factory.get('/', format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- request.resolver_match = ResolverMatch('get', (), {})
+ request = factory.get(
+ "/", format="json", HTTP_AUTHORIZATION=self.permitted_credentials
+ )
+ request.resolver_match = ResolverMatch("get", (), {})
response = api_root_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_get_queryset_has_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
+ request = factory.post(
+ "/",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.permitted_credentials,
+ )
response = get_queryset_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
def test_has_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
+ request = factory.put(
+ "/1",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.permitted_credentials,
+ )
+ response = instance_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_has_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ request = factory.delete("/1", HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_does_not_have_create_permissions(self):
- request = factory.post('/', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
+ request = factory.post(
+ "/",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.disallowed_credentials,
+ )
response = root_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_put_permissions(self):
- request = factory.put('/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.disallowed_credentials)
- response = instance_view(request, pk='1')
+ request = factory.put(
+ "/1",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.disallowed_credentials,
+ )
+ response = instance_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
+ request = factory.delete("/1", HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self):
- request = factory.options(
- '/',
- HTTP_AUTHORIZATION=self.permitted_credentials
- )
- response = root_view(request, pk='1')
+ request = factory.options("/", HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = root_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions']), ['POST'])
+ self.assertIn("actions", response.data)
+ self.assertEqual(list(response.data["actions"]), ["POST"])
- request = factory.options(
- '/1',
- HTTP_AUTHORIZATION=self.permitted_credentials
- )
- response = instance_view(request, pk='1')
+ request = factory.options("/1", HTTP_AUTHORIZATION=self.permitted_credentials)
+ response = instance_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions']), ['PUT'])
+ self.assertIn("actions", response.data)
+ self.assertEqual(list(response.data["actions"]), ["PUT"])
def test_options_disallowed(self):
- request = factory.options(
- '/',
- HTTP_AUTHORIZATION=self.disallowed_credentials
- )
- response = root_view(request, pk='1')
+ request = factory.options("/", HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = root_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
+ self.assertNotIn("actions", response.data)
- request = factory.options(
- '/1',
- HTTP_AUTHORIZATION=self.disallowed_credentials
- )
- response = instance_view(request, pk='1')
+ request = factory.options("/1", HTTP_AUTHORIZATION=self.disallowed_credentials)
+ response = instance_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
+ self.assertNotIn("actions", response.data)
def test_options_updateonly(self):
- request = factory.options(
- '/',
- HTTP_AUTHORIZATION=self.updateonly_credentials
- )
- response = root_view(request, pk='1')
+ request = factory.options("/", HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = root_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertNotIn('actions', response.data)
+ self.assertNotIn("actions", response.data)
- request = factory.options(
- '/1',
- HTTP_AUTHORIZATION=self.updateonly_credentials
- )
- response = instance_view(request, pk='1')
+ request = factory.options("/1", HTTP_AUTHORIZATION=self.updateonly_credentials)
+ response = instance_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertIn('actions', response.data)
- self.assertEqual(list(response.data['actions']), ['PUT'])
+ self.assertIn("actions", response.data)
+ self.assertEqual(list(response.data["actions"]), ["PUT"])
def test_empty_view_does_not_assert(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
+ request = factory.get("/1", HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_calling_method_not_allowed(self):
- request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.permitted_credentials)
+ request = factory.generic(
+ "METHOD_NOT_ALLOWED", "/", HTTP_AUTHORIZATION=self.permitted_credentials
+ )
response = root_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
- request = factory.generic('METHOD_NOT_ALLOWED', '/1', HTTP_AUTHORIZATION=self.permitted_credentials)
- response = instance_view(request, pk='1')
+ request = factory.generic(
+ "METHOD_NOT_ALLOWED", "/1", HTTP_AUTHORIZATION=self.permitted_credentials
+ )
+ response = instance_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_check_auth_before_queryset_call(self):
class View(RootView):
def get_queryset(_):
- self.fail('should not reach due to auth check')
+ self.fail("should not reach due to auth check")
+
view = View.as_view()
- request = factory.get('/', HTTP_AUTHORIZATION='')
+ request = factory.get("/", HTTP_AUTHORIZATION="")
response = view(request)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
@@ -228,10 +251,11 @@ class ModelPermissionsIntegrationTests(TestCase):
class View(views.APIView):
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions]
+
view = View.as_view()
- request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials)
- msg = 'Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method.'
+ request = factory.get("/", HTTP_AUTHORIZATION=self.permitted_credentials)
+ msg = "Cannot apply DjangoModelPermissions on a view that does not set `.queryset` or have a `.get_queryset()` method."
with self.assertRaisesMessage(AssertionError, msg):
view(request)
@@ -239,10 +263,13 @@ class ModelPermissionsIntegrationTests(TestCase):
class View(RootView):
def get_queryset(self):
return None
+
view = View.as_view()
- request = factory.get('/', HTTP_AUTHORIZATION=self.permitted_credentials)
- with self.assertRaisesMessage(AssertionError, 'View.get_queryset() returned None'):
+ request = factory.get("/", HTTP_AUTHORIZATION=self.permitted_credentials)
+ with self.assertRaisesMessage(
+ AssertionError, "View.get_queryset() returned None"
+ ):
view(request)
@@ -250,11 +277,11 @@ class BasicPermModel(models.Model):
text = models.CharField(max_length=100)
class Meta:
- app_label = 'tests'
+ app_label = "tests"
if django.VERSION < (2, 1):
permissions = (
- ('view_basicpermmodel', 'Can view basic perm model'),
+ ("view_basicpermmodel", "Can view basic perm model"),
# add, change, delete built in to django
)
@@ -262,19 +289,19 @@ class BasicPermModel(models.Model):
class BasicPermSerializer(serializers.ModelSerializer):
class Meta:
model = BasicPermModel
- fields = '__all__'
+ fields = "__all__"
# Custom object-level permission, that includes 'view' permissions
class ViewObjectPermissions(permissions.DjangoObjectPermissions):
perms_map = {
- 'GET': ['%(app_label)s.view_%(model_name)s'],
- 'OPTIONS': ['%(app_label)s.view_%(model_name)s'],
- 'HEAD': ['%(app_label)s.view_%(model_name)s'],
- 'POST': ['%(app_label)s.add_%(model_name)s'],
- 'PUT': ['%(app_label)s.change_%(model_name)s'],
- 'PATCH': ['%(app_label)s.change_%(model_name)s'],
- 'DELETE': ['%(app_label)s.delete_%(model_name)s'],
+ "GET": ["%(app_label)s.view_%(model_name)s"],
+ "OPTIONS": ["%(app_label)s.view_%(model_name)s"],
+ "HEAD": ["%(app_label)s.view_%(model_name)s"],
+ "POST": ["%(app_label)s.add_%(model_name)s"],
+ "PUT": ["%(app_label)s.change_%(model_name)s"],
+ "PATCH": ["%(app_label)s.change_%(model_name)s"],
+ "DELETE": ["%(app_label)s.delete_%(model_name)s"],
}
@@ -310,103 +337,114 @@ class GetQuerysetObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIV
get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.as_view()
-@unittest.skipUnless(is_guardian_installed(), 'django-guardian not installed')
+@unittest.skipUnless(is_guardian_installed(), "django-guardian not installed")
class ObjectPermissionsIntegrationTests(TestCase):
"""
Integration tests for the object level permissions API.
"""
+
def setUp(self):
from guardian.shortcuts import assign_perm
# create users
create = User.objects.create_user
users = {
- 'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
- 'readonly': create('readonly', 'readonly@example.com', 'password'),
- 'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
- 'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
+ "fullaccess": create("fullaccess", "fullaccess@example.com", "password"),
+ "readonly": create("readonly", "readonly@example.com", "password"),
+ "writeonly": create("writeonly", "writeonly@example.com", "password"),
+ "deleteonly": create("deleteonly", "deleteonly@example.com", "password"),
}
# give everyone model level permissions, as we are not testing those
- everyone = Group.objects.create(name='everyone')
+ everyone = Group.objects.create(name="everyone")
model_name = BasicPermModel._meta.model_name
app_label = BasicPermModel._meta.app_label
- f = '{0}_{1}'.format
+ f = "{0}_{1}".format
perms = {
- 'view': f('view', model_name),
- 'change': f('change', model_name),
- 'delete': f('delete', model_name)
+ "view": f("view", model_name),
+ "change": f("change", model_name),
+ "delete": f("delete", model_name),
}
for perm in perms.values():
- perm = '{0}.{1}'.format(app_label, perm)
+ perm = "{0}.{1}".format(app_label, perm)
assign_perm(perm, everyone)
everyone.user_set.add(*users.values())
# appropriate object level permissions
- readers = Group.objects.create(name='readers')
- writers = Group.objects.create(name='writers')
- deleters = Group.objects.create(name='deleters')
+ readers = Group.objects.create(name="readers")
+ writers = Group.objects.create(name="writers")
+ deleters = Group.objects.create(name="deleters")
- model = BasicPermModel.objects.create(text='foo')
+ model = BasicPermModel.objects.create(text="foo")
- assign_perm(perms['view'], readers, model)
- assign_perm(perms['change'], writers, model)
- assign_perm(perms['delete'], deleters, model)
+ assign_perm(perms["view"], readers, model)
+ assign_perm(perms["change"], writers, model)
+ assign_perm(perms["delete"], deleters, model)
- readers.user_set.add(users['fullaccess'], users['readonly'])
- writers.user_set.add(users['fullaccess'], users['writeonly'])
- deleters.user_set.add(users['fullaccess'], users['deleteonly'])
+ readers.user_set.add(users["fullaccess"], users["readonly"])
+ writers.user_set.add(users["fullaccess"], users["writeonly"])
+ deleters.user_set.add(users["fullaccess"], users["deleteonly"])
self.credentials = {}
for user in users.values():
- self.credentials[user.username] = basic_auth_header(user.username, 'password')
+ self.credentials[user.username] = basic_auth_header(
+ user.username, "password"
+ )
# Delete
def test_can_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
- response = object_permissions_view(request, pk='1')
+ request = factory.delete(
+ "/1", HTTP_AUTHORIZATION=self.credentials["deleteonly"]
+ )
+ response = object_permissions_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_cannot_delete_permissions(self):
- request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
+ request = factory.delete("/1", HTTP_AUTHORIZATION=self.credentials["readonly"])
+ response = object_permissions_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# Update
def test_can_update_permissions(self):
request = factory.patch(
- '/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['writeonly']
+ "/1",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.credentials["writeonly"],
)
- response = object_permissions_view(request, pk='1')
+ response = object_permissions_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data.get('text'), 'foobar')
+ self.assertEqual(response.data.get("text"), "foobar")
def test_cannot_update_permissions(self):
request = factory.patch(
- '/1', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly']
+ "/1",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.credentials["deleteonly"],
)
- response = object_permissions_view(request, pk='1')
+ response = object_permissions_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_cannot_update_permissions_non_existing(self):
request = factory.patch(
- '/999', {'text': 'foobar'}, format='json',
- HTTP_AUTHORIZATION=self.credentials['deleteonly']
+ "/999",
+ {"text": "foobar"},
+ format="json",
+ HTTP_AUTHORIZATION=self.credentials["deleteonly"],
)
- response = object_permissions_view(request, pk='999')
+ response = object_permissions_view(request, pk="999")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Read
def test_can_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = object_permissions_view(request, pk='1')
+ request = factory.get("/1", HTTP_AUTHORIZATION=self.credentials["readonly"])
+ response = object_permissions_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_cannot_read_permissions(self):
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- response = object_permissions_view(request, pk='1')
+ request = factory.get("/1", HTTP_AUTHORIZATION=self.credentials["writeonly"])
+ response = object_permissions_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_can_read_get_queryset_permissions(self):
@@ -414,8 +452,8 @@ class ObjectPermissionsIntegrationTests(TestCase):
same as ``test_can_read_permissions`` but with a view
that rely on ``.get_queryset()`` instead of ``.queryset``.
"""
- request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
- response = get_queryset_object_permissions_view(request, pk='1')
+ request = factory.get("/1", HTTP_AUTHORIZATION=self.credentials["readonly"])
+ response = get_queryset_object_permissions_view(request, pk="1")
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Read list
@@ -424,25 +462,31 @@ class ObjectPermissionsIntegrationTests(TestCase):
warnings.simplefilter("always")
DjangoObjectPermissionsFilter()
- message = ("`DjangoObjectPermissionsFilter` has been deprecated and moved "
- "to the 3rd-party django-rest-framework-guardian package.")
+ message = (
+ "`DjangoObjectPermissionsFilter` has been deprecated and moved "
+ "to the 3rd-party django-rest-framework-guardian package."
+ )
self.assertEqual(len(w), 1)
self.assertIs(w[-1].category, RemovedInDRF310Warning)
self.assertEqual(str(w[-1].message), message)
def test_can_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ request = factory.get("/", HTTP_AUTHORIZATION=self.credentials["readonly"])
+ object_permissions_list_view.cls.filter_backends = (
+ DjangoObjectPermissionsFilter,
+ )
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
- self.assertEqual(response.data[0].get('id'), 1)
+ self.assertEqual(response.data[0].get("id"), 1)
def test_cannot_read_list_permissions(self):
- request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
- object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
+ request = factory.get("/", HTTP_AUTHORIZATION=self.credentials["writeonly"])
+ object_permissions_list_view.cls.filter_backends = (
+ DjangoObjectPermissionsFilter,
+ )
# TODO: remove in version 3.10
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
@@ -451,7 +495,9 @@ class ObjectPermissionsIntegrationTests(TestCase):
self.assertListEqual(response.data, [])
def test_cannot_method_not_allowed(self):
- request = factory.generic('METHOD_NOT_ALLOWED', '/', HTTP_AUTHORIZATION=self.credentials['readonly'])
+ request = factory.generic(
+ "METHOD_NOT_ALLOWED", "/", HTTP_AUTHORIZATION=self.credentials["readonly"]
+ )
response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
@@ -462,7 +508,7 @@ class BasicPerm(permissions.BasePermission):
class BasicPermWithDetail(permissions.BasePermission):
- message = 'Custom: You cannot access this resource'
+ message = "Custom: You cannot access this resource"
def has_permission(self, request, view):
return False
@@ -474,7 +520,7 @@ class BasicObjectPerm(permissions.BasePermission):
class BasicObjectPermWithDetail(permissions.BasePermission):
- message = 'Custom: You cannot access this resource'
+ message = "Custom: You cannot access this resource"
def has_object_permission(self, request, view, obj):
return False
@@ -512,135 +558,136 @@ denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
class CustomPermissionsTests(TestCase):
def setUp(self):
- BasicModel(text='foo').save()
- User.objects.create_user('username', 'username@example.com', 'password')
- credentials = basic_auth_header('username', 'password')
- self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
- self.custom_message = 'Custom: You cannot access this resource'
+ BasicModel(text="foo").save()
+ User.objects.create_user("username", "username@example.com", "password")
+ credentials = basic_auth_header("username", "password")
+ self.request = factory.get("/1", format="json", HTTP_AUTHORIZATION=credentials)
+ self.custom_message = "Custom: You cannot access this resource"
def test_permission_denied(self):
- response = denied_view(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertNotEqual(detail, self.custom_message)
+ response = denied_view(self.request, pk=1)
+ detail = response.data.get("detail")
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertNotEqual(detail, self.custom_message)
def test_permission_denied_with_custom_detail(self):
- response = denied_view_with_detail(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(detail, self.custom_message)
+ response = denied_view_with_detail(self.request, pk=1)
+ detail = response.data.get("detail")
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(detail, self.custom_message)
def test_permission_denied_for_object(self):
- response = denied_object_view(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertNotEqual(detail, self.custom_message)
+ response = denied_object_view(self.request, pk=1)
+ detail = response.data.get("detail")
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertNotEqual(detail, self.custom_message)
def test_permission_denied_for_object_with_custom_detail(self):
- response = denied_object_view_with_detail(self.request, pk=1)
- detail = response.data.get('detail')
- self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
- self.assertEqual(detail, self.custom_message)
+ response = denied_object_view_with_detail(self.request, pk=1)
+ detail = response.data.get("detail")
+ self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
+ self.assertEqual(detail, self.custom_message)
class PermissionsCompositionTests(TestCase):
-
def setUp(self):
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
- self.user = User.objects.create_user(
- self.username,
- self.email,
- self.password
- )
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
+ self.user = User.objects.create_user(self.username, self.email, self.password)
self.client.login(username=self.username, password=self.password)
def test_and_false(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is False
def test_and_true(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = self.user
composed_perm = permissions.IsAuthenticated & permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
def test_or_false(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = AnonymousUser()
composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
def test_or_true(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = self.user
composed_perm = permissions.IsAuthenticated | permissions.AllowAny
assert composed_perm().has_permission(request, None) is True
def test_not_false(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = AnonymousUser()
composed_perm = ~permissions.IsAuthenticated
assert composed_perm().has_permission(request, None) is True
def test_not_true(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = self.user
composed_perm = ~permissions.AllowAny
assert composed_perm().has_permission(request, None) is False
def test_several_levels_without_negation(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = self.user
composed_perm = (
- permissions.IsAuthenticated &
- permissions.IsAuthenticated &
- permissions.IsAuthenticated &
permissions.IsAuthenticated
+ & permissions.IsAuthenticated
+ & permissions.IsAuthenticated
+ & permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence_with_negation(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = self.user
composed_perm = (
- permissions.IsAuthenticated &
- ~ permissions.IsAdminUser &
- permissions.IsAuthenticated &
- ~(permissions.IsAdminUser & permissions.IsAdminUser)
+ permissions.IsAuthenticated
+ & ~permissions.IsAdminUser
+ & permissions.IsAuthenticated
+ & ~(permissions.IsAdminUser & permissions.IsAdminUser)
)
assert composed_perm().has_permission(request, None) is True
def test_several_levels_and_precedence(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = self.user
composed_perm = (
- permissions.IsAuthenticated &
- permissions.IsAuthenticated |
- permissions.IsAuthenticated &
- permissions.IsAuthenticated
+ permissions.IsAuthenticated & permissions.IsAuthenticated
+ | permissions.IsAuthenticated & permissions.IsAuthenticated
)
assert composed_perm().has_permission(request, None) is True
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_or_lazyness(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = AnonymousUser()
- with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
+ with mock.patch.object(
+ permissions.AllowAny, "has_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.AllowAny | permissions.IsAuthenticated
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True)
mock_allow.assert_called_once()
mock_deny.assert_not_called()
- with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
+ with mock.patch.object(
+ permissions.AllowAny, "has_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.IsAuthenticated | permissions.AllowAny
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True)
mock_deny.assert_called_once()
@@ -648,20 +695,28 @@ class PermissionsCompositionTests(TestCase):
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_or_lazyness(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = AnonymousUser()
- with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
+ with mock.patch.object(
+ permissions.AllowAny, "has_object_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_object_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.AllowAny | permissions.IsAuthenticated
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True)
mock_allow.assert_called_once()
mock_deny.assert_not_called()
- with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
+ with mock.patch.object(
+ permissions.AllowAny, "has_object_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_object_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.IsAuthenticated | permissions.AllowAny
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True)
mock_deny.assert_called_once()
@@ -669,20 +724,28 @@ class PermissionsCompositionTests(TestCase):
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_and_lazyness(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = AnonymousUser()
- with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
+ with mock.patch.object(
+ permissions.AllowAny, "has_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.AllowAny & permissions.IsAuthenticated
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False)
mock_allow.assert_called_once()
mock_deny.assert_called_once()
- with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
+ with mock.patch.object(
+ permissions.AllowAny, "has_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.IsAuthenticated & permissions.AllowAny
hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False)
mock_allow.assert_not_called()
@@ -690,20 +753,28 @@ class PermissionsCompositionTests(TestCase):
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_and_lazyness(self):
- request = factory.get('/1', format='json')
+ request = factory.get("/1", format="json")
request.user = AnonymousUser()
- with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
+ with mock.patch.object(
+ permissions.AllowAny, "has_object_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_object_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.AllowAny & permissions.IsAuthenticated
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False)
mock_allow.assert_called_once()
mock_deny.assert_called_once()
- with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
- with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
- composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
+ with mock.patch.object(
+ permissions.AllowAny, "has_object_permission", return_value=True
+ ) as mock_allow:
+ with mock.patch.object(
+ permissions.IsAuthenticated, "has_object_permission", return_value=False
+ ) as mock_deny:
+ composed_perm = permissions.IsAuthenticated & permissions.AllowAny
hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False)
mock_allow.assert_not_called()
diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py
index b07087c97..ae4004f80 100644
--- a/tests/test_prefetch_related.py
+++ b/tests/test_prefetch_related.py
@@ -4,38 +4,41 @@ from django.test import TestCase
from rest_framework import generics, serializers
from rest_framework.test import APIRequestFactory
+
factory = APIRequestFactory()
class UserSerializer(serializers.ModelSerializer):
class Meta:
model = User
- fields = ('id', 'username', 'email', 'groups')
+ fields = ("id", "username", "email", "groups")
class UserUpdate(generics.UpdateAPIView):
- queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
+ queryset = User.objects.exclude(username="exclude").prefetch_related("groups")
serializer_class = UserSerializer
class TestPrefetchRelatedUpdates(TestCase):
def setUp(self):
- self.user = User.objects.create(username='tom', email='tom@example.com')
- self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
+ self.user = User.objects.create(username="tom", email="tom@example.com")
+ self.groups = [Group.objects.create(name="a"), Group.objects.create(name="b")]
self.user.groups.set(self.groups)
def test_prefetch_related_updates(self):
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
- request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
+ request = factory.put(
+ "/", {"username": "new", "groups": [groups_pk]}, format="json"
+ )
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
- 'id': pk,
- 'username': 'new',
- 'groups': [1],
- 'email': 'tom@example.com'
+ "id": pk,
+ "username": "new",
+ "groups": [1],
+ "email": "tom@example.com",
}
assert response.data == expected
@@ -46,13 +49,15 @@ class TestPrefetchRelatedUpdates(TestCase):
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
- request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
+ request = factory.put(
+ "/", {"username": "exclude", "groups": [groups_pk]}, format="json"
+ )
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
- 'id': pk,
- 'username': 'exclude',
- 'groups': [1],
- 'email': 'tom@example.com'
+ "id": pk,
+ "username": "exclude",
+ "groups": [1],
+ "email": "tom@example.com",
}
assert response.data == expected
diff --git a/tests/test_relations.py b/tests/test_relations.py
index 3c4b7d90b..c38d6cbca 100644
--- a/tests/test_relations.py
+++ b/tests/test_relations.py
@@ -11,19 +11,17 @@ from rest_framework import relations, serializers
from rest_framework.fields import empty
from rest_framework.test import APISimpleTestCase
-from .utils import (
- BadType, MockObject, MockQueryset, fail_reverse, mock_reverse
-)
+from .utils import BadType, MockObject, MockQueryset, fail_reverse, mock_reverse
class TestStringRelatedField(APISimpleTestCase):
def setUp(self):
- self.instance = MockObject(pk=1, name='foo')
+ self.instance = MockObject(pk=1, name="foo")
self.field = serializers.StringRelatedField()
def test_string_related_representation(self):
representation = self.field.to_representation(self.instance)
- assert representation == ''
+ assert representation == ""
class MockApiSettings(object):
@@ -34,48 +32,54 @@ class MockApiSettings(object):
class TestRelatedFieldHTMLCutoff(APISimpleTestCase):
def setUp(self):
- self.queryset = MockQueryset([
- MockObject(pk=i, name=str(i)) for i in range(0, 1100)
- ])
+ self.queryset = MockQueryset(
+ [MockObject(pk=i, name=str(i)) for i in range(0, 1100)]
+ )
self.monkeypatch = MonkeyPatch()
def test_no_settings(self):
# The default is 1,000, so sans settings it should be 1,000 plus one.
for many in (False, True):
- field = serializers.PrimaryKeyRelatedField(queryset=self.queryset,
- many=many)
+ field = serializers.PrimaryKeyRelatedField(
+ queryset=self.queryset, many=many
+ )
options = list(field.iter_options())
assert len(options) == 1001
assert options[-1].display_text == "More than 1000 items..."
def test_settings_cutoff(self):
- self.monkeypatch.setattr(relations, "api_settings",
- MockApiSettings(2, "Cut Off"))
+ self.monkeypatch.setattr(
+ relations, "api_settings", MockApiSettings(2, "Cut Off")
+ )
for many in (False, True):
- field = serializers.PrimaryKeyRelatedField(queryset=self.queryset,
- many=many)
+ field = serializers.PrimaryKeyRelatedField(
+ queryset=self.queryset, many=many
+ )
options = list(field.iter_options())
assert len(options) == 3 # 2 real items plus the 'Cut Off' item.
assert options[-1].display_text == "Cut Off"
def test_settings_cutoff_none(self):
# Setting it to None should mean no limit; the default limit is 1,000.
- self.monkeypatch.setattr(relations, "api_settings",
- MockApiSettings(None, "Cut Off"))
+ self.monkeypatch.setattr(
+ relations, "api_settings", MockApiSettings(None, "Cut Off")
+ )
for many in (False, True):
- field = serializers.PrimaryKeyRelatedField(queryset=self.queryset,
- many=many)
+ field = serializers.PrimaryKeyRelatedField(
+ queryset=self.queryset, many=many
+ )
options = list(field.iter_options())
assert len(options) == 1100
def test_settings_kwargs_cutoff(self):
# The explicit argument should override the settings.
- self.monkeypatch.setattr(relations, "api_settings",
- MockApiSettings(2, "Cut Off"))
+ self.monkeypatch.setattr(
+ relations, "api_settings", MockApiSettings(2, "Cut Off")
+ )
for many in (False, True):
- field = serializers.PrimaryKeyRelatedField(queryset=self.queryset,
- many=many,
- html_cutoff=100)
+ field = serializers.PrimaryKeyRelatedField(
+ queryset=self.queryset, many=many, html_cutoff=100
+ )
options = list(field.iter_options())
assert len(options) == 101
assert options[-1].display_text == "Cut Off"
@@ -83,11 +87,13 @@ class TestRelatedFieldHTMLCutoff(APISimpleTestCase):
class TestPrimaryKeyRelatedField(APISimpleTestCase):
def setUp(self):
- self.queryset = MockQueryset([
- MockObject(pk=1, name='foo'),
- MockObject(pk=2, name='bar'),
- MockObject(pk=3, name='baz')
- ])
+ self.queryset = MockQueryset(
+ [
+ MockObject(pk=1, name="foo"),
+ MockObject(pk=2, name="bar"),
+ MockObject(pk=3, name="baz"),
+ ]
+ )
self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset)
@@ -105,7 +111,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
- assert msg == 'Incorrect type. Expected pk value, received BadType.'
+ assert msg == "Incorrect type. Expected pk value, received BadType."
def test_pk_representation(self):
representation = self.field.to_representation(self.instance)
@@ -119,15 +125,16 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
def setUp(self):
- self.queryset = MockQueryset([
- MockObject(pk=uuid.UUID(int=0), name='foo'),
- MockObject(pk=uuid.UUID(int=1), name='bar'),
- MockObject(pk=uuid.UUID(int=2), name='baz')
- ])
+ self.queryset = MockQueryset(
+ [
+ MockObject(pk=uuid.UUID(int=0), name="foo"),
+ MockObject(pk=uuid.UUID(int=1), name="bar"),
+ MockObject(pk=uuid.UUID(int=2), name="baz"),
+ ]
+ )
self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField(
- queryset=self.queryset,
- pk_field=serializers.UUIDField(format='int')
+ queryset=self.queryset, pk_field=serializers.UUIDField(format="int")
)
def test_pk_related_lookup_exists(self):
@@ -138,41 +145,41 @@ class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(4)
msg = excinfo.value.detail[0]
- assert msg == 'Invalid pk "00000000-0000-0000-0000-000000000004" - object does not exist.'
+ assert (
+ msg
+ == 'Invalid pk "00000000-0000-0000-0000-000000000004" - object does not exist.'
+ )
def test_pk_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == self.instance.pk.int
-@override_settings(ROOT_URLCONF=[
- url(r'^example/(?P.+)/$', lambda: None, name='example'),
-])
+@override_settings(
+ ROOT_URLCONF=[url(r"^example/(?P.+)/$", lambda: None, name="example")]
+)
class TestHyperlinkedRelatedField(APISimpleTestCase):
def setUp(self):
- self.queryset = MockQueryset([
- MockObject(pk=1, name='foobar'),
- MockObject(pk=2, name='bazABCqux'),
- ])
+ self.queryset = MockQueryset(
+ [MockObject(pk=1, name="foobar"), MockObject(pk=2, name="bazABCqux")]
+ )
self.field = serializers.HyperlinkedRelatedField(
- view_name='example',
- lookup_field='name',
- lookup_url_kwarg='name',
+ view_name="example",
+ lookup_field="name",
+ lookup_url_kwarg="name",
queryset=self.queryset,
)
self.field.reverse = mock_reverse
- self.field._context = {'request': True}
+ self.field._context = {"request": True}
def test_representation_unsaved_object_with_non_nullable_pk(self):
- representation = self.field.to_representation(MockObject(pk=''))
+ representation = self.field.to_representation(MockObject(pk=""))
assert representation is None
def test_serialize_empty_relationship_attribute(self):
class TestSerializer(serializers.Serializer):
via_unreachable = serializers.HyperlinkedRelatedField(
- source='does_not_exist.unreachable',
- view_name='example',
- read_only=True,
+ source="does_not_exist.unreachable", view_name="example", read_only=True
)
class TestSerializable:
@@ -181,30 +188,32 @@ class TestHyperlinkedRelatedField(APISimpleTestCase):
raise ObjectDoesNotExist
serializer = TestSerializer(TestSerializable())
- assert serializer.data == {'via_unreachable': None}
+ assert serializer.data == {"via_unreachable": None}
def test_hyperlinked_related_lookup_exists(self):
- instance = self.field.to_internal_value('http://example.org/example/foobar/')
+ instance = self.field.to_internal_value("http://example.org/example/foobar/")
assert instance is self.queryset.items[0]
def test_hyperlinked_related_lookup_url_encoded_exists(self):
- instance = self.field.to_internal_value('http://example.org/example/baz%41%42%43qux/')
+ instance = self.field.to_internal_value(
+ "http://example.org/example/baz%41%42%43qux/"
+ )
assert instance is self.queryset.items[1]
def test_hyperlinked_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
- self.field.to_internal_value('http://example.org/example/doesnotexist/')
+ self.field.to_internal_value("http://example.org/example/doesnotexist/")
msg = excinfo.value.detail[0]
- assert msg == 'Invalid hyperlink - Object does not exist.'
+ assert msg == "Invalid hyperlink - Object does not exist."
def test_hyperlinked_related_internal_type_error(self):
class Field(serializers.HyperlinkedRelatedField):
def get_object(self, incorrect, signature):
raise NotImplementedError()
- field = Field(view_name='example', queryset=self.queryset)
+ field = Field(view_name="example", queryset=self.queryset)
with pytest.raises(TypeError):
- field.to_internal_value('http://example.org/example/doesnotexist/')
+ field.to_internal_value("http://example.org/example/doesnotexist/")
def hyperlinked_related_queryset_error(self, exc_type):
class QuerySet:
@@ -212,14 +221,12 @@ class TestHyperlinkedRelatedField(APISimpleTestCase):
raise exc_type
field = serializers.HyperlinkedRelatedField(
- view_name='example',
- lookup_field='name',
- queryset=QuerySet(),
+ view_name="example", lookup_field="name", queryset=QuerySet()
)
with pytest.raises(serializers.ValidationError) as excinfo:
- field.to_internal_value('http://example.org/example/doesnotexist/')
+ field.to_internal_value("http://example.org/example/doesnotexist/")
msg = excinfo.value.detail[0]
- assert msg == 'Invalid hyperlink - Object does not exist.'
+ assert msg == "Invalid hyperlink - Object does not exist."
def test_hyperlinked_related_queryset_type_error(self):
self.hyperlinked_related_queryset_error(TypeError)
@@ -230,23 +237,23 @@ class TestHyperlinkedRelatedField(APISimpleTestCase):
class TestHyperlinkedIdentityField(APISimpleTestCase):
def setUp(self):
- self.instance = MockObject(pk=1, name='foo')
- self.field = serializers.HyperlinkedIdentityField(view_name='example')
+ self.instance = MockObject(pk=1, name="foo")
+ self.field = serializers.HyperlinkedIdentityField(view_name="example")
self.field.reverse = mock_reverse
- self.field._context = {'request': True}
+ self.field._context = {"request": True}
def test_representation(self):
representation = self.field.to_representation(self.instance)
- assert representation == 'http://example.org/example/1/'
+ assert representation == "http://example.org/example/1/"
def test_representation_unsaved_object(self):
representation = self.field.to_representation(MockObject(pk=None))
assert representation is None
def test_representation_with_format(self):
- self.field._context['format'] = 'xml'
+ self.field._context["format"] = "xml"
representation = self.field.to_representation(self.instance)
- assert representation == 'http://example.org/example/1.xml/'
+ assert representation == "http://example.org/example/1.xml/"
def test_improperly_configured(self):
"""
@@ -270,31 +277,35 @@ class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase):
"""
def setUp(self):
- self.instance = MockObject(pk=1, name='foo')
- self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json')
+ self.instance = MockObject(pk=1, name="foo")
+ self.field = serializers.HyperlinkedIdentityField(
+ view_name="example", format="json"
+ )
self.field.reverse = mock_reverse
- self.field._context = {'request': True}
+ self.field._context = {"request": True}
def test_representation(self):
representation = self.field.to_representation(self.instance)
- assert representation == 'http://example.org/example/1/'
+ assert representation == "http://example.org/example/1/"
def test_representation_with_format(self):
- self.field._context['format'] = 'xml'
+ self.field._context["format"] = "xml"
representation = self.field.to_representation(self.instance)
- assert representation == 'http://example.org/example/1.json/'
+ assert representation == "http://example.org/example/1.json/"
class TestSlugRelatedField(APISimpleTestCase):
def setUp(self):
- self.queryset = MockQueryset([
- MockObject(pk=1, name='foo'),
- MockObject(pk=2, name='bar'),
- MockObject(pk=3, name='baz')
- ])
+ self.queryset = MockQueryset(
+ [
+ MockObject(pk=1, name="foo"),
+ MockObject(pk=2, name="bar"),
+ MockObject(pk=3, name="baz"),
+ ]
+ )
self.instance = self.queryset.items[2]
self.field = serializers.SlugRelatedField(
- slug_field='name', queryset=self.queryset
+ slug_field="name", queryset=self.queryset
)
def test_slug_related_lookup_exists(self):
@@ -303,15 +314,15 @@ class TestSlugRelatedField(APISimpleTestCase):
def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
- self.field.to_internal_value('doesnotexist')
+ self.field.to_internal_value("doesnotexist")
msg = excinfo.value.detail[0]
- assert msg == 'Object with name=doesnotexist does not exist.'
+ assert msg == "Object with name=doesnotexist does not exist."
def test_slug_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
- assert msg == 'Invalid value.'
+ assert msg == "Invalid value."
def test_representation(self):
representation = self.field.to_representation(self.instance)
@@ -324,47 +335,48 @@ class TestSlugRelatedField(APISimpleTestCase):
def get_queryset(self):
return qs
- field = NoQuerySetSlugRelatedField(slug_field='name')
+ field = NoQuerySetSlugRelatedField(slug_field="name")
field.to_internal_value(self.instance.name)
class TestManyRelatedField(APISimpleTestCase):
def setUp(self):
- self.instance = MockObject(pk=1, name='foo')
+ self.instance = MockObject(pk=1, name="foo")
self.field = serializers.StringRelatedField(many=True)
- self.field.field_name = 'foo'
+ self.field.field_name = "foo"
def test_get_value_regular_dictionary_full(self):
- assert 'bar' == self.field.get_value({'foo': 'bar'})
- assert empty == self.field.get_value({'baz': 'bar'})
+ assert "bar" == self.field.get_value({"foo": "bar"})
+ assert empty == self.field.get_value({"baz": "bar"})
def test_get_value_regular_dictionary_partial(self):
- setattr(self.field.root, 'partial', True)
- assert 'bar' == self.field.get_value({'foo': 'bar'})
- assert empty == self.field.get_value({'baz': 'bar'})
+ setattr(self.field.root, "partial", True)
+ assert "bar" == self.field.get_value({"foo": "bar"})
+ assert empty == self.field.get_value({"baz": "bar"})
def test_get_value_multi_dictionary_full(self):
- mvd = MultiValueDict({'foo': ['bar1', 'bar2']})
- assert ['bar1', 'bar2'] == self.field.get_value(mvd)
+ mvd = MultiValueDict({"foo": ["bar1", "bar2"]})
+ assert ["bar1", "bar2"] == self.field.get_value(mvd)
- mvd = MultiValueDict({'baz': ['bar1', 'bar2']})
+ mvd = MultiValueDict({"baz": ["bar1", "bar2"]})
assert [] == self.field.get_value(mvd)
def test_get_value_multi_dictionary_partial(self):
- setattr(self.field.root, 'partial', True)
- mvd = MultiValueDict({'foo': ['bar1', 'bar2']})
- assert ['bar1', 'bar2'] == self.field.get_value(mvd)
+ setattr(self.field.root, "partial", True)
+ mvd = MultiValueDict({"foo": ["bar1", "bar2"]})
+ assert ["bar1", "bar2"] == self.field.get_value(mvd)
- mvd = MultiValueDict({'baz': ['bar1', 'bar2']})
+ mvd = MultiValueDict({"baz": ["bar1", "bar2"]})
assert empty == self.field.get_value(mvd)
class TestHyperlink:
def setup(self):
- self.default_hyperlink = serializers.Hyperlink('http://example.com', 'test')
+ self.default_hyperlink = serializers.Hyperlink("http://example.com", "test")
def test_can_be_pickled(self):
import pickle
+
upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
assert upkled == self.default_hyperlink
assert upkled.name == self.default_hyperlink.name
diff --git a/tests/test_relations_hyperlink.py b/tests/test_relations_hyperlink.py
index 887a6f423..7204a60eb 100644
--- a/tests/test_relations_hyperlink.py
+++ b/tests/test_relations_hyperlink.py
@@ -6,12 +6,18 @@ from django.test import TestCase, override_settings
from rest_framework import serializers
from rest_framework.test import APIRequestFactory
from tests.models import (
- ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget,
- NullableForeignKeySource, NullableOneToOneSource, OneToOneTarget
+ ForeignKeySource,
+ ForeignKeyTarget,
+ ManyToManySource,
+ ManyToManyTarget,
+ NullableForeignKeySource,
+ NullableOneToOneSource,
+ OneToOneTarget,
)
+
factory = APIRequestFactory()
-request = factory.get('/') # Just to ensure we have a request in the serializer context
+request = factory.get("/") # Just to ensure we have a request in the serializer context
def dummy_view(request, pk):
@@ -19,14 +25,38 @@ def dummy_view(request, pk):
urlpatterns = [
- url(r'^dummyurl/(?P[0-9]+)/$', dummy_view, name='dummy-url'),
- url(r'^manytomanysource/(?P[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
- url(r'^manytomanytarget/(?P[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
- url(r'^foreignkeysource/(?P[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
- url(r'^foreignkeytarget/(?P[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
- url(r'^nullableforeignkeysource/(?P[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
- url(r'^onetoonetarget/(?P[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
- url(r'^nullableonetoonesource/(?P[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
+ url(r"^dummyurl/(?P[0-9]+)/$", dummy_view, name="dummy-url"),
+ url(
+ r"^manytomanysource/(?P[0-9]+)/$",
+ dummy_view,
+ name="manytomanysource-detail",
+ ),
+ url(
+ r"^manytomanytarget/(?P[0-9]+)/$",
+ dummy_view,
+ name="manytomanytarget-detail",
+ ),
+ url(
+ r"^foreignkeysource/(?P[0-9]+)/$",
+ dummy_view,
+ name="foreignkeysource-detail",
+ ),
+ url(
+ r"^foreignkeytarget/(?P[0-9]+)/$",
+ dummy_view,
+ name="foreignkeytarget-detail",
+ ),
+ url(
+ r"^nullableforeignkeysource/(?P[0-9]+)/$",
+ dummy_view,
+ name="nullableforeignkeysource-detail",
+ ),
+ url(r"^onetoonetarget/(?P[0-9]+)/$", dummy_view, name="onetoonetarget-detail"),
+ url(
+ r"^nullableonetoonesource/(?P[0-9]+)/$",
+ dummy_view,
+ name="nullableonetoonesource-detail",
+ ),
]
@@ -34,237 +64,505 @@ urlpatterns = [
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManyTarget
- fields = ('url', 'name', 'sources')
+ fields = ("url", "name", "sources")
class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ManyToManySource
- fields = ('url', 'name', 'targets')
+ fields = ("url", "name", "targets")
# ForeignKey
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeyTarget
- fields = ('url', 'name', 'sources')
+ fields = ("url", "name", "sources")
class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = ForeignKeySource
- fields = ('url', 'name', 'target')
+ fields = ("url", "name", "target")
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
- fields = ('url', 'name', 'target')
+ fields = ("url", "name", "target")
# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = OneToOneTarget
- fields = ('url', 'name', 'nullable_source')
+ fields = ("url", "name", "nullable_source")
# TODO: Add test that .data cannot be accessed prior to .is_valid
-@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
+@override_settings(ROOT_URLCONF="tests.test_relations_hyperlink")
class HyperlinkedManyToManyTests(TestCase):
def setUp(self):
for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
+ target = ManyToManyTarget(name="target-%d" % idx)
target.save()
- source = ManyToManySource(name='source-%d' % idx)
+ source = ManyToManySource(name="source-%d" % idx)
source.save()
for target in ManyToManyTarget.objects.all():
source.targets.add(target)
def test_relative_hyperlinks(self):
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
+ serializer = ManyToManySourceSerializer(
+ queryset, many=True, context={"request": None}
+ )
expected = [
- {'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
- {'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
- {'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
+ {
+ "url": "/manytomanysource/1/",
+ "name": "source-1",
+ "targets": ["/manytomanytarget/1/"],
+ },
+ {
+ "url": "/manytomanysource/2/",
+ "name": "source-2",
+ "targets": ["/manytomanytarget/1/", "/manytomanytarget/2/"],
+ },
+ {
+ "url": "/manytomanysource/3/",
+ "name": "source-3",
+ "targets": [
+ "/manytomanytarget/1/",
+ "/manytomanytarget/2/",
+ "/manytomanytarget/3/",
+ ],
+ },
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = ManyToManySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ {
+ "url": "http://testserver/manytomanysource/1/",
+ "name": "source-1",
+ "targets": ["http://testserver/manytomanytarget/1/"],
+ },
+ {
+ "url": "http://testserver/manytomanysource/2/",
+ "name": "source-2",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanysource/3/",
+ "name": "source-3",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ "http://testserver/manytomanytarget/3/",
+ ],
+ },
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self):
- queryset = ManyToManySource.objects.all().prefetch_related('targets')
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ queryset = ManyToManySource.objects.all().prefetch_related("targets")
+ serializer = ManyToManySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
with self.assertNumQueries(2):
serializer.data
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ serializer = ManyToManyTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
+ {
+ "url": "http://testserver/manytomanytarget/1/",
+ "name": "target-1",
+ "sources": [
+ "http://testserver/manytomanysource/1/",
+ "http://testserver/manytomanysource/2/",
+ "http://testserver/manytomanysource/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanytarget/2/",
+ "name": "target-2",
+ "sources": [
+ "http://testserver/manytomanysource/2/",
+ "http://testserver/manytomanysource/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanytarget/3/",
+ "name": "target-3",
+ "sources": ["http://testserver/manytomanysource/3/"],
+ },
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ data = {
+ "url": "http://testserver/manytomanysource/1/",
+ "name": "source-1",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ "http://testserver/manytomanytarget/3/",
+ ],
+ }
instance = ManyToManySource.objects.get(pk=1)
- serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
+ serializer = ManyToManySourceSerializer(
+ instance, data=data, context={"request": request}
+ )
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = ManyToManySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
+ {
+ "url": "http://testserver/manytomanysource/1/",
+ "name": "source-1",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ "http://testserver/manytomanytarget/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanysource/2/",
+ "name": "source-2",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanysource/3/",
+ "name": "source-3",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ "http://testserver/manytomanytarget/3/",
+ ],
+ },
]
assert serializer.data == expected
def test_reverse_many_to_many_update(self):
- data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
+ data = {
+ "url": "http://testserver/manytomanytarget/1/",
+ "name": "target-1",
+ "sources": ["http://testserver/manytomanysource/1/"],
+ }
instance = ManyToManyTarget.objects.get(pk=1)
- serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
+ serializer = ManyToManyTargetSerializer(
+ instance, data=data, context={"request": request}
+ )
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ serializer = ManyToManyTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
-
+ {
+ "url": "http://testserver/manytomanytarget/1/",
+ "name": "target-1",
+ "sources": ["http://testserver/manytomanysource/1/"],
+ },
+ {
+ "url": "http://testserver/manytomanytarget/2/",
+ "name": "target-2",
+ "sources": [
+ "http://testserver/manytomanysource/2/",
+ "http://testserver/manytomanysource/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanytarget/3/",
+ "name": "target-3",
+ "sources": ["http://testserver/manytomanysource/3/"],
+ },
]
assert serializer.data == expected
def test_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
- serializer = ManyToManySourceSerializer(data=data, context={'request': request})
+ data = {
+ "url": "http://testserver/manytomanysource/4/",
+ "name": "source-4",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/3/",
+ ],
+ }
+ serializer = ManyToManySourceSerializer(data=data, context={"request": request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
- serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = ManyToManySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
- {'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
- {'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
- {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
+ {
+ "url": "http://testserver/manytomanysource/1/",
+ "name": "source-1",
+ "targets": ["http://testserver/manytomanytarget/1/"],
+ },
+ {
+ "url": "http://testserver/manytomanysource/2/",
+ "name": "source-2",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanysource/3/",
+ "name": "source-3",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/2/",
+ "http://testserver/manytomanytarget/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanysource/4/",
+ "name": "source-4",
+ "targets": [
+ "http://testserver/manytomanytarget/1/",
+ "http://testserver/manytomanytarget/3/",
+ ],
+ },
]
assert serializer.data == expected
def test_reverse_many_to_many_create(self):
- data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
- serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
+ data = {
+ "url": "http://testserver/manytomanytarget/4/",
+ "name": "target-4",
+ "sources": [
+ "http://testserver/manytomanysource/1/",
+ "http://testserver/manytomanysource/3/",
+ ],
+ }
+ serializer = ManyToManyTargetSerializer(data=data, context={"request": request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'target-4'
+ assert obj.name == "target-4"
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
- serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
+ serializer = ManyToManyTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
- {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
+ {
+ "url": "http://testserver/manytomanytarget/1/",
+ "name": "target-1",
+ "sources": [
+ "http://testserver/manytomanysource/1/",
+ "http://testserver/manytomanysource/2/",
+ "http://testserver/manytomanysource/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanytarget/2/",
+ "name": "target-2",
+ "sources": [
+ "http://testserver/manytomanysource/2/",
+ "http://testserver/manytomanysource/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/manytomanytarget/3/",
+ "name": "target-3",
+ "sources": ["http://testserver/manytomanysource/3/"],
+ },
+ {
+ "url": "http://testserver/manytomanytarget/4/",
+ "name": "target-4",
+ "sources": [
+ "http://testserver/manytomanysource/1/",
+ "http://testserver/manytomanysource/3/",
+ ],
+ },
]
assert serializer.data == expected
-@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
+@override_settings(ROOT_URLCONF="tests.test_relations_hyperlink")
class HyperlinkedForeignKeyTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
+ target = ForeignKeyTarget(name="target-1")
target.save()
- new_target = ForeignKeyTarget(name='target-2')
+ new_target = ForeignKeyTarget(name="target-2")
new_target.save()
for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source = ForeignKeySource(name="source-%d" % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = ForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ {
+ "url": "http://testserver/foreignkeysource/1/",
+ "name": "source-1",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/foreignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/foreignkeysource/3/",
+ "name": "source-3",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
]
with self.assertNumQueries(1):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ serializer = ForeignKeyTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ {
+ "url": "http://testserver/foreignkeytarget/1/",
+ "name": "target-1",
+ "sources": [
+ "http://testserver/foreignkeysource/1/",
+ "http://testserver/foreignkeysource/2/",
+ "http://testserver/foreignkeysource/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/foreignkeytarget/2/",
+ "name": "target-2",
+ "sources": [],
+ },
]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
+ data = {
+ "url": "http://testserver/foreignkeysource/1/",
+ "name": "source-1",
+ "target": "http://testserver/foreignkeytarget/2/",
+ }
instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ serializer = ForeignKeySourceSerializer(
+ instance, data=data, context={"request": request}
+ )
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = ForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
+ {
+ "url": "http://testserver/foreignkeysource/1/",
+ "name": "source-1",
+ "target": "http://testserver/foreignkeytarget/2/",
+ },
+ {
+ "url": "http://testserver/foreignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/foreignkeysource/3/",
+ "name": "source-3",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
+ data = {
+ "url": "http://testserver/foreignkeysource/1/",
+ "name": "source-1",
+ "target": 2,
+ }
instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ serializer = ForeignKeySourceSerializer(
+ instance, data=data, context={"request": request}
+ )
assert not serializer.is_valid()
- assert serializer.errors == {'target': ['Incorrect type. Expected URL string, received int.']}
+ assert serializer.errors == {
+ "target": ["Incorrect type. Expected URL string, received int."]
+ }
def test_reverse_foreign_key_update(self):
- data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
+ data = {
+ "url": "http://testserver/foreignkeytarget/2/",
+ "name": "target-2",
+ "sources": [
+ "http://testserver/foreignkeysource/1/",
+ "http://testserver/foreignkeysource/3/",
+ ],
+ }
instance = ForeignKeyTarget.objects.get(pk=2)
- serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
+ serializer = ForeignKeyTargetSerializer(
+ instance, data=data, context={"request": request}
+ )
assert serializer.is_valid()
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
- new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ new_serializer = ForeignKeyTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
+ {
+ "url": "http://testserver/foreignkeytarget/1/",
+ "name": "target-1",
+ "sources": [
+ "http://testserver/foreignkeysource/1/",
+ "http://testserver/foreignkeysource/2/",
+ "http://testserver/foreignkeysource/3/",
+ ],
+ },
+ {
+ "url": "http://testserver/foreignkeytarget/2/",
+ "name": "target-2",
+ "sources": [],
+ },
]
assert new_serializer.data == expected
@@ -273,95 +571,198 @@ class HyperlinkedForeignKeyTests(TestCase):
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ serializer = ForeignKeyTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ {
+ "url": "http://testserver/foreignkeytarget/1/",
+ "name": "target-1",
+ "sources": ["http://testserver/foreignkeysource/2/"],
+ },
+ {
+ "url": "http://testserver/foreignkeytarget/2/",
+ "name": "target-2",
+ "sources": [
+ "http://testserver/foreignkeysource/1/",
+ "http://testserver/foreignkeysource/3/",
+ ],
+ },
]
assert serializer.data == expected
def test_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
- serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
+ data = {
+ "url": "http://testserver/foreignkeysource/4/",
+ "name": "source-4",
+ "target": "http://testserver/foreignkeytarget/2/",
+ }
+ serializer = ForeignKeySourceSerializer(data=data, context={"request": request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
- serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = ForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
+ {
+ "url": "http://testserver/foreignkeysource/1/",
+ "name": "source-1",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/foreignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/foreignkeysource/3/",
+ "name": "source-3",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/foreignkeysource/4/",
+ "name": "source-4",
+ "target": "http://testserver/foreignkeytarget/2/",
+ },
]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
- data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
- serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
+ data = {
+ "url": "http://testserver/foreignkeytarget/3/",
+ "name": "target-3",
+ "sources": [
+ "http://testserver/foreignkeysource/1/",
+ "http://testserver/foreignkeysource/3/",
+ ],
+ }
+ serializer = ForeignKeyTargetSerializer(data=data, context={"request": request})
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'target-3'
+ assert obj.name == "target-3"
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
- serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
+ serializer = ForeignKeyTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
- {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
- {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
+ {
+ "url": "http://testserver/foreignkeytarget/1/",
+ "name": "target-1",
+ "sources": ["http://testserver/foreignkeysource/2/"],
+ },
+ {
+ "url": "http://testserver/foreignkeytarget/2/",
+ "name": "target-2",
+ "sources": [],
+ },
+ {
+ "url": "http://testserver/foreignkeytarget/3/",
+ "name": "target-3",
+ "sources": [
+ "http://testserver/foreignkeysource/1/",
+ "http://testserver/foreignkeysource/3/",
+ ],
+ },
]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
- data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
+ data = {
+ "url": "http://testserver/foreignkeysource/1/",
+ "name": "source-1",
+ "target": None,
+ }
instance = ForeignKeySource.objects.get(pk=1)
- serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ serializer = ForeignKeySourceSerializer(
+ instance, data=data, context={"request": request}
+ )
assert not serializer.is_valid()
- assert serializer.errors == {'target': ['This field may not be null.']}
+ assert serializer.errors == {"target": ["This field may not be null."]}
-@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
+@override_settings(ROOT_URLCONF="tests.test_relations_hyperlink")
class HyperlinkedNullableForeignKeyTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
+ target = ForeignKeyTarget(name="target-1")
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source = NullableForeignKeySource(name="source-%d" % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = NullableForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/3/",
+ "name": "source-3",
+ "target": None,
+ },
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ data = {
+ "url": "http://testserver/nullableforeignkeysource/4/",
+ "name": "source-4",
+ "target": None,
+ }
+ serializer = NullableForeignKeySourceSerializer(
+ data=data, context={"request": request}
+ )
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = NullableForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/3/",
+ "name": "source-3",
+ "target": None,
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/4/",
+ "name": "source-4",
+ "target": None,
+ },
]
assert serializer.data == expected
@@ -370,40 +771,88 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
- serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
+ data = {
+ "url": "http://testserver/nullableforeignkeysource/4/",
+ "name": "source-4",
+ "target": "",
+ }
+ expected_data = {
+ "url": "http://testserver/nullableforeignkeysource/4/",
+ "name": "source-4",
+ "target": None,
+ }
+ serializer = NullableForeignKeySourceSerializer(
+ data=data, context={"request": request}
+ )
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = NullableForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
+ {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/3/",
+ "name": "source-3",
+ "target": None,
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/4/",
+ "name": "source-4",
+ "target": None,
+ },
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ data = {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": None,
+ }
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ serializer = NullableForeignKeySourceSerializer(
+ instance, data=data, context={"request": request}
+ )
assert serializer.is_valid()
serializer.save()
assert serializer.data == data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = NullableForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": None,
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/3/",
+ "name": "source-3",
+ "target": None,
+ },
]
assert serializer.data == expected
@@ -412,40 +861,74 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
- expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
+ data = {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": "",
+ }
+ expected_data = {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": None,
+ }
instance = NullableForeignKeySource.objects.get(pk=1)
- serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
+ serializer = NullableForeignKeySourceSerializer(
+ instance, data=data, context={"request": request}
+ )
assert serializer.is_valid()
serializer.save()
assert serializer.data == expected_data
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
- serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
+ serializer = NullableForeignKeySourceSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
- {'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
- {'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
+ {
+ "url": "http://testserver/nullableforeignkeysource/1/",
+ "name": "source-1",
+ "target": None,
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/2/",
+ "name": "source-2",
+ "target": "http://testserver/foreignkeytarget/1/",
+ },
+ {
+ "url": "http://testserver/nullableforeignkeysource/3/",
+ "name": "source-3",
+ "target": None,
+ },
]
assert serializer.data == expected
-@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
+@override_settings(ROOT_URLCONF="tests.test_relations_hyperlink")
class HyperlinkedNullableOneToOneTests(TestCase):
def setUp(self):
- target = OneToOneTarget(name='target-1')
+ target = OneToOneTarget(name="target-1")
target.save()
- new_target = OneToOneTarget(name='target-2')
+ new_target = OneToOneTarget(name="target-2")
new_target.save()
- source = NullableOneToOneSource(name='source-1', target=target)
+ source = NullableOneToOneSource(name="source-1", target=target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
- serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
+ serializer = NullableOneToOneTargetSerializer(
+ queryset, many=True, context={"request": request}
+ )
expected = [
- {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
- {'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
+ {
+ "url": "http://testserver/onetoonetarget/1/",
+ "name": "target-1",
+ "nullable_source": "http://testserver/nullableonetoonesource/1/",
+ },
+ {
+ "url": "http://testserver/onetoonetarget/2/",
+ "name": "target-2",
+ "nullable_source": None,
+ },
]
assert serializer.data == expected
diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py
index 2cffb62e6..212471dbc 100644
--- a/tests/test_relations_pk.py
+++ b/tests/test_relations_pk.py
@@ -5,11 +5,18 @@ from django.utils import six
from rest_framework import serializers
from tests.models import (
- ForeignKeySource, ForeignKeySourceWithLimitedChoices,
- ForeignKeySourceWithQLimitedChoices, ForeignKeyTarget, ManyToManySource,
- ManyToManyTarget, NullableForeignKeySource, NullableOneToOneSource,
- NullableUUIDForeignKeySource, OneToOnePKSource, OneToOneTarget,
- UUIDForeignKeyTarget
+ ForeignKeySource,
+ ForeignKeySourceWithLimitedChoices,
+ ForeignKeySourceWithQLimitedChoices,
+ ForeignKeyTarget,
+ ManyToManySource,
+ ManyToManyTarget,
+ NullableForeignKeySource,
+ NullableOneToOneSource,
+ NullableUUIDForeignKeySource,
+ OneToOnePKSource,
+ OneToOneTarget,
+ UUIDForeignKeyTarget,
)
@@ -17,26 +24,26 @@ from tests.models import (
class ManyToManyTargetSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManyTarget
- fields = ('id', 'name', 'sources')
+ fields = ("id", "name", "sources")
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
- fields = ('id', 'name', 'targets')
+ fields = ("id", "name", "targets")
# ForeignKey
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeyTarget
- fields = ('id', 'name', 'sources')
+ fields = ("id", "name", "sources")
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
- fields = ('id', 'name', 'target')
+ fields = ("id", "name", "target")
class ForeignKeySourceWithLimitedChoicesSerializer(serializers.ModelSerializer):
@@ -49,7 +56,7 @@ class ForeignKeySourceWithLimitedChoicesSerializer(serializers.ModelSerializer):
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = NullableForeignKeySource
- fields = ('id', 'name', 'target')
+ fields = ("id", "name", "target")
# Nullable UUIDForeignKey
@@ -57,35 +64,36 @@ class NullableUUIDForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.PrimaryKeyRelatedField(
pk_field=serializers.UUIDField(),
queryset=UUIDForeignKeyTarget.objects.all(),
- allow_null=True)
+ allow_null=True,
+ )
class Meta:
model = NullableUUIDForeignKeySource
- fields = ('id', 'name', 'target')
+ fields = ("id", "name", "target")
# Nullable OneToOne
class NullableOneToOneTargetSerializer(serializers.ModelSerializer):
class Meta:
model = OneToOneTarget
- fields = ('id', 'name', 'nullable_source')
+ fields = ("id", "name", "nullable_source")
class OneToOnePKSourceSerializer(serializers.ModelSerializer):
-
class Meta:
model = OneToOnePKSource
- fields = '__all__'
+ fields = "__all__"
# TODO: Add test that .data cannot be accessed prior to .is_valid
+
class PKManyToManyTests(TestCase):
def setUp(self):
for idx in range(1, 4):
- target = ManyToManyTarget(name='target-%d' % idx)
+ target = ManyToManyTarget(name="target-%d" % idx)
target.save()
- source = ManyToManySource(name='source-%d' % idx)
+ source = ManyToManySource(name="source-%d" % idx)
source.save()
for target in ManyToManyTarget.objects.all():
source.targets.add(target)
@@ -94,15 +102,15 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ {"id": 1, "name": "source-1", "targets": [1]},
+ {"id": 2, "name": "source-2", "targets": [1, 2]},
+ {"id": 3, "name": "source-3", "targets": [1, 2, 3]},
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_retrieve_prefetch_related(self):
- queryset = ManyToManySource.objects.all().prefetch_related('targets')
+ queryset = ManyToManySource.objects.all().prefetch_related("targets")
serializer = ManyToManySourceSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
@@ -111,15 +119,15 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
+ {"id": 1, "name": "target-1", "sources": [1, 2, 3]},
+ {"id": 2, "name": "target-2", "sources": [2, 3]},
+ {"id": 3, "name": "target-3", "sources": [3]},
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_many_to_many_update(self):
- data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
+ data = {"id": 1, "name": "source-1", "targets": [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data)
assert serializer.is_valid()
@@ -130,14 +138,14 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
+ {"id": 1, "name": "source-1", "targets": [1, 2, 3]},
+ {"id": 2, "name": "source-2", "targets": [1, 2]},
+ {"id": 3, "name": "source-3", "targets": [1, 2, 3]},
]
assert serializer.data == expected
def test_reverse_many_to_many_update(self):
- data = {'id': 1, 'name': 'target-1', 'sources': [1]}
+ data = {"id": 1, "name": "target-1", "sources": [1]}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data)
assert serializer.is_valid()
@@ -148,78 +156,78 @@ class PKManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]}
+ {"id": 1, "name": "target-1", "sources": [1]},
+ {"id": 2, "name": "target-2", "sources": [2, 3]},
+ {"id": 3, "name": "target-3", "sources": [3]},
]
assert serializer.data == expected
def test_many_to_many_create(self):
- data = {'id': 4, 'name': 'source-4', 'targets': [1, 3]}
+ data = {"id": 4, "name": "source-4", "targets": [1, 3]}
serializer = ManyToManySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'targets': [1]},
- {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
- {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]},
- {'id': 4, 'name': 'source-4', 'targets': [1, 3]},
+ {"id": 1, "name": "source-1", "targets": [1]},
+ {"id": 2, "name": "source-2", "targets": [1, 2]},
+ {"id": 3, "name": "source-3", "targets": [1, 2, 3]},
+ {"id": 4, "name": "source-4", "targets": [1, 3]},
]
assert serializer.data == expected
def test_many_to_many_unsaved(self):
- source = ManyToManySource(name='source-unsaved')
+ source = ManyToManySource(name="source-unsaved")
serializer = ManyToManySourceSerializer(source)
- expected = {'id': None, 'name': 'source-unsaved', 'targets': []}
+ expected = {"id": None, "name": "source-unsaved", "targets": []}
# no query if source hasn't been created yet
with self.assertNumQueries(0):
assert serializer.data == expected
def test_reverse_many_to_many_create(self):
- data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ data = {"id": 4, "name": "target-4", "sources": [1, 3]}
serializer = ManyToManyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'target-4'
+ assert obj.name == "target-4"
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
- {'id': 3, 'name': 'target-3', 'sources': [3]},
- {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
+ {"id": 1, "name": "target-1", "sources": [1, 2, 3]},
+ {"id": 2, "name": "target-2", "sources": [2, 3]},
+ {"id": 3, "name": "target-3", "sources": [3]},
+ {"id": 4, "name": "target-4", "sources": [1, 3]},
]
assert serializer.data == expected
class PKForeignKeyTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
+ target = ForeignKeyTarget(name="target-1")
target.save()
- new_target = ForeignKeyTarget(name='target-2')
+ new_target = ForeignKeyTarget(name="target-2")
new_target.save()
for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source = ForeignKeySource(name="source-%d" % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
+ {"id": 1, "name": "source-1", "target": 1},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": 1},
]
with self.assertNumQueries(1):
assert serializer.data == expected
@@ -228,20 +236,20 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
+ {"id": 1, "name": "target-1", "sources": [1, 2, 3]},
+ {"id": 2, "name": "target-2", "sources": []},
]
with self.assertNumQueries(3):
assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self):
- queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
+ queryset = ForeignKeyTarget.objects.all().prefetch_related("sources")
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 2}
+ data = {"id": 1, "name": "source-1", "target": 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
@@ -252,21 +260,26 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 2},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1}
+ {"id": 1, "name": "source-1", "target": 2},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": 1},
]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'foo'}
+ data = {"id": 1, "name": "source-1", "target": "foo"}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
- assert serializer.errors == {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}
+ assert serializer.errors == {
+ "target": [
+ "Incorrect type. Expected pk value, received %s."
+ % six.text_type.__name__
+ ]
+ }
def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
+ data = {"id": 2, "name": "target-2", "sources": [1, 3]}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid()
@@ -275,8 +288,8 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
- {'id': 2, 'name': 'target-2', 'sources': []},
+ {"id": 1, "name": "target-1", "sources": [1, 2, 3]},
+ {"id": 2, "name": "target-2", "sources": []},
]
assert new_serializer.data == expected
@@ -287,58 +300,58 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': [1, 3]},
+ {"id": 1, "name": "target-1", "sources": [2]},
+ {"id": 2, "name": "target-2", "sources": [1, 3]},
]
assert serializer.data == expected
def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 2}
+ data = {"id": 4, "name": "source-4", "target": 2}
serializer = ForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': 1},
- {'id': 4, 'name': 'source-4', 'target': 2},
+ {"id": 1, "name": "source-1", "target": 1},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": 1},
+ {"id": 4, "name": "source-4", "target": 2},
]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': [1, 3]}
+ data = {"id": 3, "name": "target-3", "sources": [1, 3]}
serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'target-3'
+ assert obj.name == "target-3"
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': [2]},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': [1, 3]},
+ {"id": 1, "name": "target-1", "sources": [2]},
+ {"id": 2, "name": "target-2", "sources": []},
+ {"id": 3, "name": "target-3", "sources": [1, 3]},
]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
+ data = {"id": 1, "name": "source-1", "target": None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
- assert serializer.errors == {'target': ['This field may not be null.']}
+ assert serializer.errors == {"target": ["This field may not be null."]}
def test_foreign_key_with_unsaved(self):
- source = ForeignKeySource(name='source-unsaved')
- expected = {'id': None, 'name': 'source-unsaved', 'target': None}
+ source = ForeignKeySource(name="source-unsaved")
+ expected = {"id": None, "name": "source-unsaved", "target": None}
serializer = ForeignKeySourceSerializer(source)
@@ -353,19 +366,21 @@ class PKForeignKeyTests(TestCase):
https://github.com/encode/django-rest-framework/issues/1072
"""
serializer = NullableForeignKeySourceSerializer()
- assert serializer.data['target'] is None
+ assert serializer.data["target"] is None
def test_foreign_key_not_required(self):
"""
Let's say we wanted to fill the non-nullable model field inside
Model.save(), we would make it empty and not required.
"""
+
class ModelSerializer(ForeignKeySourceSerializer):
class Meta(ForeignKeySourceSerializer.Meta):
- extra_kwargs = {'target': {'required': False}}
- serializer = ModelSerializer(data={'name': 'test'})
+ extra_kwargs = {"target": {"required": False}}
+
+ serializer = ModelSerializer(data={"name": "test"})
serializer.is_valid(raise_exception=True)
- assert 'target' not in serializer.validated_data
+ assert "target" not in serializer.validated_data
def test_queryset_size_without_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target")
@@ -376,7 +391,11 @@ class PKForeignKeyTests(TestCase):
def test_queryset_size_with_limited_choices(self):
limited_target = ForeignKeyTarget(name="limited-target")
limited_target.save()
- queryset = ForeignKeySourceWithLimitedChoicesSerializer().fields["target"].get_queryset()
+ queryset = (
+ ForeignKeySourceWithLimitedChoicesSerializer()
+ .fields["target"]
+ .get_queryset()
+ )
assert len(queryset) == 1
def test_queryset_size_with_Q_limited_choices(self):
@@ -394,40 +413,40 @@ class PKForeignKeyTests(TestCase):
class PKNullableForeignKeyTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
+ target = ForeignKeyTarget(name="target-1")
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source = NullableForeignKeySource(name="source-%d" % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
+ {"id": 1, "name": "source-1", "target": 1},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": None},
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
+ data = {"id": 4, "name": "source-4", "target": None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
+ {"id": 1, "name": "source-1", "target": 1},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": None},
+ {"id": 4, "name": "source-4", "target": None},
]
assert serializer.data == expected
@@ -436,27 +455,27 @@ class PKNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ data = {"id": 4, "name": "source-4", "target": ""}
+ expected_data = {"id": 4, "name": "source-4", "target": None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 1},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
+ {"id": 1, "name": "source-1", "target": 1},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": None},
+ {"id": 4, "name": "source-4", "target": None},
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
+ data = {"id": 1, "name": "source-1", "target": None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
@@ -467,9 +486,9 @@ class PKNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
+ {"id": 1, "name": "source-1", "target": None},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": None},
]
assert serializer.data == expected
@@ -478,8 +497,8 @@ class PKNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ data = {"id": 1, "name": "source-1", "target": ""}
+ expected_data = {"id": 1, "name": "source-1", "target": None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
@@ -490,14 +509,14 @@ class PKNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 1},
- {'id': 3, 'name': 'source-3', 'target': None}
+ {"id": 1, "name": "source-1", "target": None},
+ {"id": 2, "name": "source-2", "target": 1},
+ {"id": 3, "name": "source-3", "target": None},
]
assert serializer.data == expected
def test_null_uuid_foreign_key_serializes_as_none(self):
- source = NullableUUIDForeignKeySource(name='Source')
+ source = NullableUUIDForeignKeySource(name="Source")
serializer = NullableUUIDForeignKeySourceSerializer(source)
data = serializer.data
assert data["target"] is None
@@ -510,39 +529,44 @@ class PKNullableForeignKeyTests(TestCase):
class PKNullableOneToOneTests(TestCase):
def setUp(self):
- target = OneToOneTarget(name='target-1')
+ target = OneToOneTarget(name="target-1")
target.save()
- new_target = OneToOneTarget(name='target-2')
+ new_target = OneToOneTarget(name="target-2")
new_target.save()
- source = NullableOneToOneSource(name='source-1', target=new_target)
+ source = NullableOneToOneSource(name="source-1", target=new_target)
source.save()
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'nullable_source': None},
- {'id': 2, 'name': 'target-2', 'nullable_source': 1},
+ {"id": 1, "name": "target-1", "nullable_source": None},
+ {"id": 2, "name": "target-2", "nullable_source": 1},
]
assert serializer.data == expected
class OneToOnePrimaryKeyTests(TestCase):
-
def setUp(self):
# Given: Some target models already exist
- self.target = target = OneToOneTarget(name='target-1')
+ self.target = target = OneToOneTarget(name="target-1")
target.save()
- self.alt_target = alt_target = OneToOneTarget(name='target-2')
+ self.alt_target = alt_target = OneToOneTarget(name="target-2")
alt_target.save()
def test_one_to_one_when_primary_key(self):
# When: Creating a Source pointing at the id of the second Target
target_pk = self.alt_target.id
- source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
+ source = OneToOnePKSourceSerializer(
+ data={"name": "source-2", "target": target_pk}
+ )
# Then: The source is valid with the serializer
if not source.is_valid():
- self.fail("Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(source.errors))
+ self.fail(
+ "Expected OneToOnePKTargetSerializer to be valid but had errors: {}".format(
+ source.errors
+ )
+ )
# Then: Saving the serializer creates a new object
new_source = source.save()
# Then: The new object has the same pk as the target object
@@ -551,7 +575,7 @@ class OneToOnePrimaryKeyTests(TestCase):
def test_one_to_one_when_primary_key_no_duplicates(self):
# When: Creating a Source pointing at the id of the second Target
target_pk = self.target.id
- data = {'name': 'source-1', 'target': target_pk}
+ data = {"name": "source-1", "target": target_pk}
source = OneToOnePKSourceSerializer(data=data)
# Then: The source is valid with the serializer
self.assertTrue(source.is_valid())
@@ -562,14 +586,16 @@ class OneToOnePrimaryKeyTests(TestCase):
# When: Trying to create a second object
second_source = OneToOnePKSourceSerializer(data=data)
self.assertFalse(second_source.is_valid())
- expected = {'target': [u'one to one pk source with this target already exists.']}
+ expected = {"target": ["one to one pk source with this target already exists."]}
self.assertDictEqual(second_source.errors, expected)
def test_one_to_one_when_primary_key_does_not_exist(self):
# Given: a target PK that does not exist
target_pk = self.target.pk + self.alt_target.pk
- source = OneToOnePKSourceSerializer(data={'name': 'source-2', 'target': target_pk})
+ source = OneToOnePKSourceSerializer(
+ data={"name": "source-2", "target": target_pk}
+ )
# Then: The source is not valid with the serializer
self.assertFalse(source.is_valid())
- self.assertIn("Invalid pk", source.errors['target'][0])
- self.assertIn("object does not exist", source.errors['target'][0])
+ self.assertIn("Invalid pk", source.errors["target"][0])
+ self.assertIn("object does not exist", source.errors["target"][0])
diff --git a/tests/test_relations_slug.py b/tests/test_relations_slug.py
index 0b9ca79d3..c1a1ff33f 100644
--- a/tests/test_relations_slug.py
+++ b/tests/test_relations_slug.py
@@ -1,70 +1,63 @@
from django.test import TestCase
from rest_framework import serializers
-from tests.models import (
- ForeignKeySource, ForeignKeyTarget, NullableForeignKeySource
-)
+from tests.models import ForeignKeySource, ForeignKeyTarget, NullableForeignKeySource
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
sources = serializers.SlugRelatedField(
- slug_field='name',
- queryset=ForeignKeySource.objects.all(),
- many=True
+ slug_field="name", queryset=ForeignKeySource.objects.all(), many=True
)
class Meta:
model = ForeignKeyTarget
- fields = '__all__'
+ fields = "__all__"
class ForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(
- slug_field='name',
- queryset=ForeignKeyTarget.objects.all()
+ slug_field="name", queryset=ForeignKeyTarget.objects.all()
)
class Meta:
model = ForeignKeySource
- fields = '__all__'
+ fields = "__all__"
class NullableForeignKeySourceSerializer(serializers.ModelSerializer):
target = serializers.SlugRelatedField(
- slug_field='name',
- queryset=ForeignKeyTarget.objects.all(),
- allow_null=True
+ slug_field="name", queryset=ForeignKeyTarget.objects.all(), allow_null=True
)
class Meta:
model = NullableForeignKeySource
- fields = '__all__'
+ fields = "__all__"
# TODO: M2M Tests, FKTests (Non-nullable), One2One
class SlugForeignKeyTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
+ target = ForeignKeyTarget(name="target-1")
target.save()
- new_target = ForeignKeyTarget(name='target-2')
+ new_target = ForeignKeyTarget(name="target-2")
new_target.save()
for idx in range(1, 4):
- source = ForeignKeySource(name='source-%d' % idx, target=target)
+ source = ForeignKeySource(name="source-%d" % idx, target=target)
source.save()
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ {"id": 1, "name": "source-1", "target": "target-1"},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": "target-1"},
]
with self.assertNumQueries(4):
assert serializer.data == expected
def test_foreign_key_retrieve_select_related(self):
- queryset = ForeignKeySource.objects.all().select_related('target')
+ queryset = ForeignKeySource.objects.all().select_related("target")
serializer = ForeignKeySourceSerializer(queryset, many=True)
with self.assertNumQueries(1):
serializer.data
@@ -73,19 +66,23 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
+ {
+ "id": 1,
+ "name": "target-1",
+ "sources": ["source-1", "source-2", "source-3"],
+ },
+ {"id": 2, "name": "target-2", "sources": []},
]
assert serializer.data == expected
def test_reverse_foreign_key_retrieve_prefetch_related(self):
- queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
+ queryset = ForeignKeyTarget.objects.all().prefetch_related("sources")
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data
def test_foreign_key_update(self):
- data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
+ data = {"id": 1, "name": "source-1", "target": "target-2"}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
@@ -96,21 +93,21 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-2'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'}
+ {"id": 1, "name": "source-1", "target": "target-2"},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": "target-1"},
]
assert serializer.data == expected
def test_foreign_key_update_incorrect_type(self):
- data = {'id': 1, 'name': 'source-1', 'target': 123}
+ data = {"id": 1, "name": "source-1", "target": 123}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
- assert serializer.errors == {'target': ['Object with name=123 does not exist.']}
+ assert serializer.errors == {"target": ["Object with name=123 does not exist."]}
def test_reverse_foreign_key_update(self):
- data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
+ data = {"id": 2, "name": "target-2", "sources": ["source-1", "source-3"]}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
assert serializer.is_valid()
@@ -119,8 +116,12 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
- {'id': 2, 'name': 'target-2', 'sources': []},
+ {
+ "id": 1,
+ "name": "target-1",
+ "sources": ["source-1", "source-2", "source-3"],
+ },
+ {"id": 2, "name": "target-2", "sources": []},
]
assert new_serializer.data == expected
@@ -131,93 +132,93 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']},
+ {"id": 1, "name": "target-1", "sources": ["source-2"]},
+ {"id": 2, "name": "target-2", "sources": ["source-1", "source-3"]},
]
assert serializer.data == expected
def test_foreign_key_create(self):
- data = {'id': 4, 'name': 'source-4', 'target': 'target-2'}
+ data = {"id": 4, "name": "source-4", "target": "target-2"}
serializer = ForeignKeySourceSerializer(data=data)
serializer.is_valid()
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is added, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': 'target-1'},
- {'id': 4, 'name': 'source-4', 'target': 'target-2'},
+ {"id": 1, "name": "source-1", "target": "target-1"},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": "target-1"},
+ {"id": 4, "name": "source-4", "target": "target-2"},
]
assert serializer.data == expected
def test_reverse_foreign_key_create(self):
- data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
+ data = {"id": 3, "name": "target-3", "sources": ["source-1", "source-3"]}
serializer = ForeignKeyTargetSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'target-3'
+ assert obj.name == "target-3"
# Ensure target 3 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'target-1', 'sources': ['source-2']},
- {'id': 2, 'name': 'target-2', 'sources': []},
- {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']},
+ {"id": 1, "name": "target-1", "sources": ["source-2"]},
+ {"id": 2, "name": "target-2", "sources": []},
+ {"id": 3, "name": "target-3", "sources": ["source-1", "source-3"]},
]
assert serializer.data == expected
def test_foreign_key_update_with_invalid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
+ data = {"id": 1, "name": "source-1", "target": None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
assert not serializer.is_valid()
- assert serializer.errors == {'target': ['This field may not be null.']}
+ assert serializer.errors == {"target": ["This field may not be null."]}
class SlugNullableForeignKeyTests(TestCase):
def setUp(self):
- target = ForeignKeyTarget(name='target-1')
+ target = ForeignKeyTarget(name="target-1")
target.save()
for idx in range(1, 4):
if idx == 3:
target = None
- source = NullableForeignKeySource(name='source-%d' % idx, target=target)
+ source = NullableForeignKeySource(name="source-%d" % idx, target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
+ {"id": 1, "name": "source-1", "target": "target-1"},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": None},
]
assert serializer.data == expected
def test_foreign_key_create_with_valid_null(self):
- data = {'id': 4, 'name': 'source-4', 'target': None}
+ data = {"id": 4, "name": "source-4", "target": None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
+ {"id": 1, "name": "source-1", "target": "target-1"},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": None},
+ {"id": 4, "name": "source-4", "target": None},
]
assert serializer.data == expected
@@ -226,27 +227,27 @@ class SlugNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 4, 'name': 'source-4', 'target': ''}
- expected_data = {'id': 4, 'name': 'source-4', 'target': None}
+ data = {"id": 4, "name": "source-4", "target": ""}
+ expected_data = {"id": 4, "name": "source-4", "target": None}
serializer = NullableForeignKeySourceSerializer(data=data)
assert serializer.is_valid()
obj = serializer.save()
assert serializer.data == expected_data
- assert obj.name == 'source-4'
+ assert obj.name == "source-4"
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': 'target-1'},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None},
- {'id': 4, 'name': 'source-4', 'target': None}
+ {"id": 1, "name": "source-1", "target": "target-1"},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": None},
+ {"id": 4, "name": "source-4", "target": None},
]
assert serializer.data == expected
def test_foreign_key_update_with_valid_null(self):
- data = {'id': 1, 'name': 'source-1', 'target': None}
+ data = {"id": 1, "name": "source-1", "target": None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
@@ -257,9 +258,9 @@ class SlugNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
+ {"id": 1, "name": "source-1", "target": None},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": None},
]
assert serializer.data == expected
@@ -268,8 +269,8 @@ class SlugNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
- data = {'id': 1, 'name': 'source-1', 'target': ''}
- expected_data = {'id': 1, 'name': 'source-1', 'target': None}
+ data = {"id": 1, "name": "source-1", "target": ""}
+ expected_data = {"id": 1, "name": "source-1", "target": None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data)
assert serializer.is_valid()
@@ -280,8 +281,8 @@ class SlugNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True)
expected = [
- {'id': 1, 'name': 'source-1', 'target': None},
- {'id': 2, 'name': 'source-2', 'target': 'target-1'},
- {'id': 3, 'name': 'source-3', 'target': None}
+ {"id": 1, "name": "source-1", "target": None},
+ {"id": 2, "name": "source-2", "target": "target-1"},
+ {"id": 3, "name": "source-3", "target": None},
]
assert serializer.data == expected
diff --git a/tests/test_renderers.py b/tests/test_renderers.py
index 60a0c0307..4196a07f3 100644
--- a/tests/test_renderers.py
+++ b/tests/test_renderers.py
@@ -19,8 +19,14 @@ from rest_framework import permissions, serializers, status
from rest_framework.compat import MutableMapping, coreapi
from rest_framework.decorators import action
from rest_framework.renderers import (
- AdminRenderer, BaseRenderer, BrowsableAPIRenderer, DocumentationRenderer,
- HTMLFormRenderer, JSONRenderer, SchemaJSRenderer, StaticHTMLRenderer
+ AdminRenderer,
+ BaseRenderer,
+ BrowsableAPIRenderer,
+ DocumentationRenderer,
+ HTMLFormRenderer,
+ JSONRenderer,
+ SchemaJSRenderer,
+ StaticHTMLRenderer,
)
from rest_framework.request import Request
from rest_framework.response import Response
@@ -31,25 +37,26 @@ from rest_framework.utils import json
from rest_framework.views import APIView
from rest_framework.viewsets import ViewSet
+
DUMMYSTATUS = status.HTTP_200_OK
-DUMMYCONTENT = 'dummycontent'
+DUMMYCONTENT = "dummycontent"
def RENDERER_A_SERIALIZER(x):
- return ('Renderer A: %s' % x).encode('ascii')
+ return ("Renderer A: %s" % x).encode("ascii")
def RENDERER_B_SERIALIZER(x):
- return ('Renderer B: %s' % x).encode('ascii')
+ return ("Renderer B: %s" % x).encode("ascii")
expected_results = [
- ((elem for elem in [1, 2, 3]), JSONRenderer, b'[1,2,3]') # Generator
+ ((elem for elem in [1, 2, 3]), JSONRenderer, b"[1,2,3]") # Generator
]
class DummyTestModel(models.Model):
- name = models.CharField(max_length=42, default='')
+ name = models.CharField(max_length=42, default="")
class BasicRendererTests(TestCase):
@@ -60,7 +67,7 @@ class BasicRendererTests(TestCase):
class RendererA(BaseRenderer):
- media_type = 'mock/renderera'
+ media_type = "mock/renderera"
format = "formata"
def render(self, data, media_type=None, renderer_context=None):
@@ -68,7 +75,7 @@ class RendererA(BaseRenderer):
class RendererB(BaseRenderer):
- media_type = 'mock/rendererb'
+ media_type = "mock/rendererb"
format = "formatb"
def render(self, data, media_type=None, renderer_context=None):
@@ -85,12 +92,12 @@ class MockView(APIView):
class MockGETView(APIView):
def get(self, request, **kwargs):
- return Response({'foo': ['bar', 'baz']})
+ return Response({"foo": ["bar", "baz"]})
class MockPOSTView(APIView):
def post(self, request, **kwargs):
- return Response({'foo': request.data})
+ return Response({"foo": request.data})
class EmptyGETView(APIView):
@@ -101,34 +108,40 @@ class EmptyGETView(APIView):
class HTMLView(APIView):
- renderer_classes = (BrowsableAPIRenderer, )
+ renderer_classes = (BrowsableAPIRenderer,)
def get(self, request, **kwargs):
- return Response('text')
+ return Response("text")
class HTMLView1(APIView):
renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
def get(self, request, **kwargs):
- return Response('text')
+ return Response("text")
urlpatterns = [
- url(r'^.*\.(?P.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
- url(r'^cache$', MockGETView.as_view()),
- url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
- url(r'^html$', HTMLView.as_view()),
- url(r'^html1$', HTMLView1.as_view()),
- url(r'^empty$', EmptyGETView.as_view()),
- url(r'^api', include('rest_framework.urls', namespace='rest_framework'))
+ url(
+ r"^.*\.(?P.+)$",
+ MockView.as_view(renderer_classes=[RendererA, RendererB]),
+ ),
+ url(r"^$", MockView.as_view(renderer_classes=[RendererA, RendererB])),
+ url(r"^cache$", MockGETView.as_view()),
+ url(
+ r"^parseerror$",
+ MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer]),
+ ),
+ url(r"^html$", HTMLView.as_view()),
+ url(r"^html1$", HTMLView1.as_view()),
+ url(r"^empty$", EmptyGETView.as_view()),
+ url(r"^api", include("rest_framework.urls", namespace="rest_framework")),
]
class POSTDeniedPermission(permissions.BasePermission):
def has_permission(self, request, view):
- return request.method != 'POST'
+ return request.method != "POST"
class POSTDeniedView(APIView):
@@ -151,97 +164,96 @@ class POSTDeniedView(APIView):
class DocumentingRendererTests(TestCase):
def test_only_permitted_forms_are_displayed(self):
view = POSTDeniedView.as_view()
- request = APIRequestFactory().get('/')
+ request = APIRequestFactory().get("/")
response = view(request).render()
- self.assertNotContains(response, '>POST<')
- self.assertContains(response, '>PUT<')
- self.assertContains(response, '>PATCH<')
+ self.assertNotContains(response, ">POST<")
+ self.assertContains(response, ">PUT<")
+ self.assertContains(response, ">PATCH<")
-@override_settings(ROOT_URLCONF='tests.test_renderers')
+@override_settings(ROOT_URLCONF="tests.test_renderers")
class RendererEndToEndTests(TestCase):
"""
End-to-end testing of renderers using an RendererMixin on a generic view.
"""
+
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ resp = self.client.get("/")
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
- resp = self.client.head('/')
+ resp = self.client.head("/")
self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, six.b(''))
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
+ self.assertEqual(resp.content, six.b(""))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ resp = self.client.get("/", HTTP_ACCEPT="*/*")
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ resp = self.client.get("/", HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ resp = self.client.get("/", HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
- resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
+ resp = self.client.get("/", HTTP_ACCEPT="foo/bar")
self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ param = "?%s=%s" % (api_settings.URL_FORMAT_OVERRIDE, RendererB.format)
+ resp = self.client.get("/" + param)
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ resp = self.client.get("/something.formatb")
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
- param = '?%s=%s' % (
- api_settings.URL_FORMAT_OVERRIDE,
- RendererB.format
- )
- resp = self.client.get('/' + param,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ param = "?%s=%s" % (api_settings.URL_FORMAT_OVERRIDE, RendererB.format)
+ resp = self.client.get("/" + param, HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_parse_error_renderers_browsable_api(self):
"""Invalid data should still render the browsable API correctly."""
- resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ resp = self.client.post(
+ "/parseerror",
+ data="foobar",
+ content_type="application/json",
+ HTTP_ACCEPT="text/html",
+ )
+ self.assertEqual(resp["Content-Type"], "text/html; charset=utf-8")
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
def test_204_no_content_responses_have_no_content_type_set(self):
@@ -250,8 +262,8 @@ class RendererEndToEndTests(TestCase):
https://github.com/encode/django-rest-framework/issues/1196
"""
- resp = self.client.get('/empty')
- self.assertEqual(resp.get('Content-Type', None), None)
+ resp = self.client.get("/empty")
+ self.assertEqual(resp.get("Content-Type", None), None)
self.assertEqual(resp.status_code, status.HTTP_204_NO_CONTENT)
def test_contains_headers_of_api_response(self):
@@ -261,10 +273,10 @@ class RendererEndToEndTests(TestCase):
Test we display the headers of the API response and not those from the
HTML response
"""
- resp = self.client.get('/html1')
- self.assertContains(resp, '>GET, HEAD, OPTIONS<')
- self.assertContains(resp, '>application/json<')
- self.assertNotContains(resp, '>text/html; charset=utf-8<')
+ resp = self.client.get("/html1")
+ self.assertContains(resp, ">GET, HEAD, OPTIONS<")
+ self.assertContains(resp, ">application/json<")
+ self.assertNotContains(resp, ">text/html; charset=utf-8<")
_flat_repr = '{"foo":["bar","baz"]}'
@@ -276,19 +288,20 @@ def strip_trailing_whitespace(content):
Seems to be some inconsistencies re. trailing whitespace with
different versions of the json lib.
"""
- return re.sub(' +\n', '\n', content)
+ return re.sub(" +\n", "\n", content)
class BaseRendererTests(TestCase):
"""
Tests BaseRenderer
"""
+
def test_render_raise_error(self):
"""
BaseRenderer.render should raise NotImplementedError
"""
with pytest.raises(NotImplementedError):
- BaseRenderer().render('test')
+ BaseRenderer().render("test")
class JSONRendererTests(TestCase):
@@ -300,21 +313,21 @@ class JSONRendererTests(TestCase):
"""
JSONRenderer should deal with lazy translated strings.
"""
- ret = JSONRenderer().render(_('test'))
+ ret = JSONRenderer().render(_("test"))
self.assertEqual(ret, b'"test"')
def test_render_queryset_values(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values('id', 'name')
+ o = DummyTestModel.objects.create(name="dummy")
+ qs = DummyTestModel.objects.values("id", "name")
ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode('utf-8'))
- self.assertEqual(data, [{'id': o.id, 'name': o.name}])
+ data = json.loads(ret.decode("utf-8"))
+ self.assertEqual(data, [{"id": o.id, "name": o.name}])
def test_render_queryset_values_list(self):
- o = DummyTestModel.objects.create(name='dummy')
- qs = DummyTestModel.objects.values_list('id', 'name')
+ o = DummyTestModel.objects.create(name="dummy")
+ qs = DummyTestModel.objects.values_list("id", "name")
ret = JSONRenderer().render(qs)
- data = json.loads(ret.decode('utf-8'))
+ data = json.loads(ret.decode("utf-8"))
self.assertEqual(data, [[o.id, o.name]])
def test_render_dict_abc_obj(self):
@@ -341,11 +354,11 @@ class JSONRendererTests(TestCase):
return self._dict.keys()
x = Dict()
- x['key'] = 'string value'
+ x["key"] = "string value"
x[2] = 3
ret = JSONRenderer().render(x)
- data = json.loads(ret.decode('utf-8'))
- self.assertEqual(data, {'key': 'string value', '2': 3})
+ data = json.loads(ret.decode("utf-8"))
+ self.assertEqual(data, {"key": "string value", "2": 3})
def test_render_obj_with_getitem(self):
class DictLike(object):
@@ -359,7 +372,7 @@ class JSONRendererTests(TestCase):
return self._dict[key]
x = DictLike()
- x.set({'a': 1, 'b': 'string'})
+ x.set({"a": 1, "b": "string"})
with self.assertRaises(TypeError):
JSONRenderer().render(x)
@@ -367,81 +380,93 @@ class JSONRendererTests(TestCase):
renderer = JSONRenderer()
# Default to strict
- for value in [float('inf'), float('-inf'), float('nan')]:
+ for value in [float("inf"), float("-inf"), float("nan")]:
with pytest.raises(ValueError):
renderer.render(value)
renderer.strict = False
- assert renderer.render(float('inf')) == b'Infinity'
- assert renderer.render(float('-inf')) == b'-Infinity'
- assert renderer.render(float('nan')) == b'NaN'
+ assert renderer.render(float("inf")) == b"Infinity"
+ assert renderer.render(float("-inf")) == b"-Infinity"
+ assert renderer.render(float("nan")) == b"NaN"
def test_without_content_type_args(self):
"""
Test basic JSON rendering.
"""
- obj = {'foo': ['bar', 'baz']}
+ obj = {"foo": ["bar", "baz"]}
renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
+ content = renderer.render(obj, "application/json")
# Fix failing test case which depends on version of JSON library.
- self.assertEqual(content.decode('utf-8'), _flat_repr)
+ self.assertEqual(content.decode("utf-8"), _flat_repr)
def test_with_content_type_args(self):
"""
Test JSON rendering with additional content type arguments supplied.
"""
- obj = {'foo': ['bar', 'baz']}
+ obj = {"foo": ["bar", "baz"]}
renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json; indent=2')
- self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
+ content = renderer.render(obj, "application/json; indent=2")
+ self.assertEqual(
+ strip_trailing_whitespace(content.decode("utf-8")), _indented_repr
+ )
class UnicodeJSONRendererTests(TestCase):
"""
Tests specific for the Unicode JSON Renderer
"""
+
def test_proper_encoding(self):
- obj = {'countries': ['United Kingdom', 'France', 'España']}
+ obj = {"countries": ["United Kingdom", "France", "España"]}
renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode('utf-8'))
+ content = renderer.render(obj, "application/json")
+ self.assertEqual(
+ content,
+ '{"countries":["United Kingdom","France","España"]}'.encode("utf-8"),
+ )
def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped,
# even when the non-escaping unicode representation is used.
# Regression test for #2169
- obj = {'should_escape': '\u2028\u2029'}
+ obj = {"should_escape": "\u2028\u2029"}
renderer = JSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode('utf-8'))
+ content = renderer.render(obj, "application/json")
+ self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode("utf-8"))
class AsciiJSONRendererTests(TestCase):
"""
Tests specific for the Unicode JSON Renderer
"""
+
def test_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True
- obj = {'countries': ['United Kingdom', 'France', 'España']}
+
+ obj = {"countries": ["United Kingdom", "France", "España"]}
renderer = AsciiJSONRenderer()
- content = renderer.render(obj, 'application/json')
- self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8'))
+ content = renderer.render(obj, "application/json")
+ self.assertEqual(
+ content,
+ '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode("utf-8"),
+ )
# Tests for caching issue, #346
-@override_settings(ROOT_URLCONF='tests.test_renderers')
+@override_settings(ROOT_URLCONF="tests.test_renderers")
class CacheRenderTest(TestCase):
"""
Tests specific to caching responses
"""
+
def test_head_caching(self):
"""
Test caching of HEAD requests
"""
- response = self.client.head('/cache')
- cache.set('key', response)
- cached_response = cache.get('key')
+ response = self.client.head("/cache")
+ cache.set("key", response)
+ cached_response = cache.get("key")
assert isinstance(cached_response, Response)
assert cached_response.content == response.content
assert cached_response.status_code == response.status_code
@@ -450,9 +475,9 @@ class CacheRenderTest(TestCase):
"""
Test caching of GET requests
"""
- response = self.client.get('/cache')
- cache.set('key', response)
- cached_response = cache.get('key')
+ response = self.client.get("/cache")
+ cache.set("key", response)
+ cached_response = cache.get("key")
assert isinstance(cached_response, Response)
assert cached_response.content == response.content
assert cached_response.status_code == response.status_code
@@ -461,22 +486,22 @@ class CacheRenderTest(TestCase):
class TestJSONIndentationStyles:
def test_indented(self):
renderer = JSONRenderer()
- data = OrderedDict([('a', 1), ('b', 2)])
+ data = OrderedDict([("a", 1), ("b", 2)])
assert renderer.render(data) == b'{"a":1,"b":2}'
def test_compact(self):
renderer = JSONRenderer()
- data = OrderedDict([('a', 1), ('b', 2)])
- context = {'indent': 4}
+ data = OrderedDict([("a", 1), ("b", 2)])
+ context = {"indent": 4}
assert (
- renderer.render(data, renderer_context=context) ==
- b'{\n "a": 1,\n "b": 2\n}'
+ renderer.render(data, renderer_context=context)
+ == b'{\n "a": 1,\n "b": 2\n}'
)
def test_long_form(self):
renderer = JSONRenderer()
renderer.compact = False
- data = OrderedDict([('a', 1), ('b', 2)])
+ data = OrderedDict([("a", 1), ("b", 2)])
assert renderer.render(data) == b'{"a": 1, "b": 2}'
@@ -488,9 +513,9 @@ class TestHiddenFieldHTMLFormRenderer(TestCase):
serializer = TestSerializer(data={})
serializer.is_valid()
renderer = HTMLFormRenderer()
- field = serializer['published']
+ field = serializer["published"]
rendered = renderer.render_field(field, {})
- assert rendered == ''
+ assert rendered == ""
class TestHTMLFormRenderer(TestCase):
@@ -524,11 +549,10 @@ class TestChoiceFieldHTMLFormRenderer(TestCase):
"""
def setUp(self):
- choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
+ choices = ((1, "Option1"), (2, "Option2"), (12, "Option12"))
class TestSerializer(serializers.Serializer):
- test_field = serializers.ChoiceField(choices=choices,
- initial=2)
+ test_field = serializers.ChoiceField(choices=choices, initial=2)
self.TestSerializer = TestSerializer
self.renderer = HTMLFormRenderer()
@@ -539,21 +563,19 @@ class TestChoiceFieldHTMLFormRenderer(TestCase):
self.assertIsInstance(result, SafeText)
- self.assertInHTML('',
- result)
+ self.assertInHTML('', result)
self.assertInHTML('', result)
self.assertInHTML('', result)
def test_render_selected_option(self):
- serializer = self.TestSerializer(data={'test_field': '12'})
+ serializer = self.TestSerializer(data={"test_field": "12"})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
- self.assertInHTML('',
- result)
+ self.assertInHTML('', result)
self.assertInHTML('', result)
self.assertInHTML('', result)
@@ -567,40 +589,42 @@ class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
self.renderer = HTMLFormRenderer()
def test_render_selected_option_with_string_option_ids(self):
- choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'),
- ('}', 'OptionBrace'))
+ choices = (
+ ("1", "Option1"),
+ ("2", "Option2"),
+ ("12", "Option12"),
+ ("}", "OptionBrace"),
+ )
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
- serializer = TestSerializer(data={'test_field': ['12']})
+ serializer = TestSerializer(data={"test_field": ["12"]})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
- self.assertInHTML('',
- result)
+ self.assertInHTML('', result)
self.assertInHTML('', result)
self.assertInHTML('', result)
self.assertInHTML('', result)
def test_render_selected_option_with_integer_option_ids(self):
- choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
+ choices = ((1, "Option1"), (2, "Option2"), (12, "Option12"))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
- serializer = TestSerializer(data={'test_field': ['12']})
+ serializer = TestSerializer(data={"test_field": ["12"]})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
- self.assertInHTML('',
- result)
+ self.assertInHTML('', result)
self.assertInHTML('', result)
self.assertInHTML('', result)
@@ -609,21 +633,22 @@ class StaticHTMLRendererTests(TestCase):
"""
Tests specific for Static HTML Renderer
"""
+
def setUp(self):
self.renderer = StaticHTMLRenderer()
def test_static_renderer(self):
- data = 'text'
+ data = "text"
result = self.renderer.render(data)
assert result == data
def test_static_renderer_with_exception(self):
context = {
- 'response': Response(status=500, exception=True),
- 'request': Request(HttpRequest())
+ "response": Response(status=500, exception=True),
+ "request": Request(HttpRequest()),
}
result = self.renderer.render({}, renderer_context=context)
- assert result == '500 Internal Server Error'
+ assert result == "500 Internal Server Error"
class BrowsableAPIRendererTests(URLPatternsTestCase):
@@ -636,222 +661,224 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
raise NotImplementedError
router = SimpleRouter()
- router.register('examples', ExampleViewSet, basename='example')
- urlpatterns = [url(r'^api/', include(router.urls))]
+ router.register("examples", ExampleViewSet, basename="example")
+ urlpatterns = [url(r"^api/", include(router.urls))]
def setUp(self):
self.renderer = BrowsableAPIRenderer()
def test_get_description_returns_empty_string_for_401_and_403_statuses(self):
- assert self.renderer.get_description({}, status_code=401) == ''
- assert self.renderer.get_description({}, status_code=403) == ''
+ assert self.renderer.get_description({}, status_code=401) == ""
+ assert self.renderer.get_description({}, status_code=403) == ""
def test_get_filter_form_returns_none_if_data_is_not_list_instance(self):
class DummyView(object):
get_queryset = None
filter_backends = None
- result = self.renderer.get_filter_form(data='not list',
- view=DummyView(), request={})
+ result = self.renderer.get_filter_form(
+ data="not list", view=DummyView(), request={}
+ )
assert result is None
def test_extra_actions_dropdown(self):
- resp = self.client.get('/api/examples/', HTTP_ACCEPT='text/html')
- assert 'id="extra-actions-menu"' in resp.content.decode('utf-8')
- assert '/api/examples/list_action/' in resp.content.decode('utf-8')
- assert '>Extra list action<' in resp.content.decode('utf-8')
+ resp = self.client.get("/api/examples/", HTTP_ACCEPT="text/html")
+ assert 'id="extra-actions-menu"' in resp.content.decode("utf-8")
+ assert "/api/examples/list_action/" in resp.content.decode("utf-8")
+ assert ">Extra list action<" in resp.content.decode("utf-8")
class AdminRendererTests(TestCase):
-
def setUp(self):
self.renderer = AdminRenderer()
def test_render_when_resource_created(self):
class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
- request = Request(HttpRequest())
- request.build_absolute_uri = lambda: 'http://example.com'
- response = Response(status=201, headers={'Location': '/test'})
- context = {
- 'view': DummyView(),
- 'request': request,
- 'response': response
- }
+ renderer_classes = (AdminRenderer,)
- result = self.renderer.render(data={'test': 'test'},
- renderer_context=context)
- assert result == ''
+ request = Request(HttpRequest())
+ request.build_absolute_uri = lambda: "http://example.com"
+ response = Response(status=201, headers={"Location": "/test"})
+ context = {"view": DummyView(), "request": request, "response": response}
+
+ result = self.renderer.render(data={"test": "test"}, renderer_context=context)
+ assert result == ""
assert response.status_code == status.HTTP_303_SEE_OTHER
- assert response['Location'] == 'http://example.com'
+ assert response["Location"] == "http://example.com"
def test_render_dict(self):
factory = APIRequestFactory()
class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
+ renderer_classes = (AdminRenderer,)
def get(self, request):
- return Response({'foo': 'a string'})
+ return Response({"foo": "a string"})
+
view = DummyView.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request)
response.render()
- self.assertContains(response, 'Foo | a string |
', html=True)
+ self.assertContains(
+ response, "Foo | a string |
", html=True
+ )
def test_render_dict_with_items_key(self):
factory = APIRequestFactory()
class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
+ renderer_classes = (AdminRenderer,)
def get(self, request):
- return Response({'items': 'a string'})
+ return Response({"items": "a string"})
view = DummyView.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request)
response.render()
- self.assertContains(response, 'Items | a string |
', html=True)
+ self.assertContains(
+ response, "Items | a string |
", html=True
+ )
def test_render_dict_with_iteritems_key(self):
factory = APIRequestFactory()
class DummyView(APIView):
- renderer_classes = (AdminRenderer, )
+ renderer_classes = (AdminRenderer,)
def get(self, request):
- return Response({'iteritems': 'a string'})
+ return Response({"iteritems": "a string"})
view = DummyView.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request)
response.render()
- self.assertContains(response, 'Iteritems | a string |
', html=True)
+ self.assertContains(
+ response, "Iteritems | a string |
", html=True
+ )
def test_get_result_url(self):
factory = APIRequestFactory()
class DummyGenericViewsetLike(APIView):
- lookup_field = 'test'
+ lookup_field = "test"
def reverse_action(view, *args, **kwargs):
- self.assertEqual(kwargs['kwargs']['test'], 1)
- return '/example/'
+ self.assertEqual(kwargs["kwargs"]["test"], 1)
+ return "/example/"
# get the view instance instead of the view function
view = DummyGenericViewsetLike.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request)
- view = response.renderer_context['view']
+ view = response.renderer_context["view"]
- self.assertEqual(self.renderer.get_result_url({'test': 1}, view), '/example/')
+ self.assertEqual(self.renderer.get_result_url({"test": 1}, view), "/example/")
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_result_url_no_result(self):
factory = APIRequestFactory()
class DummyView(APIView):
- lookup_field = 'test'
+ lookup_field = "test"
# get the view instance instead of the view function
view = DummyView.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request)
- view = response.renderer_context['view']
+ view = response.renderer_context["view"]
- self.assertIsNone(self.renderer.get_result_url({'test': 1}, view))
+ self.assertIsNone(self.renderer.get_result_url({"test": 1}, view))
self.assertIsNone(self.renderer.get_result_url({}, view))
def test_get_context_result_urls(self):
factory = APIRequestFactory()
class DummyView(APIView):
- lookup_field = 'test'
+ lookup_field = "test"
def reverse_action(view, url_name, args=None, kwargs=None):
- return '/%s/%d' % (url_name, kwargs['test'])
+ return "/%s/%d" % (url_name, kwargs["test"])
# get the view instance instead of the view function
view = DummyView.as_view()
- request = factory.get('/')
+ request = factory.get("/")
response = view(request)
data = [
- {'test': 1},
- {'url': '/example', 'test': 2},
- {'url': None, 'test': 3},
+ {"test": 1},
+ {"url": "/example", "test": 2},
+ {"url": None, "test": 3},
{},
]
context = {
- 'view': DummyView(),
- 'request': Request(request),
- 'response': response
+ "view": DummyView(),
+ "request": Request(request),
+ "response": response,
}
context = self.renderer.get_context(data, None, context)
- results = context['results']
+ results = context["results"]
self.assertEqual(len(results), 4)
- self.assertEqual(results[0]['url'], '/detail/1')
- self.assertEqual(results[1]['url'], '/example')
- self.assertEqual(results[2]['url'], None)
- self.assertNotIn('url', results[3])
+ self.assertEqual(results[0]["url"], "/detail/1")
+ self.assertEqual(results[1]["url"], "/example")
+ self.assertEqual(results[2]["url"], None)
+ self.assertNotIn("url", results[3])
-@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
class TestDocumentationRenderer(TestCase):
-
def test_document_with_link_named_data(self):
"""
Ref #5395: Doc's `document.data` would fail with a Link named "data".
As per #4972, use templatetag instead.
"""
document = coreapi.Document(
- title='Data Endpoint API',
- url='https://api.example.org/',
+ title="Data Endpoint API",
+ url="https://api.example.org/",
content={
- 'data': coreapi.Link(
- url='/data/',
- action='get',
- fields=[],
- description='Return data.'
+ "data": coreapi.Link(
+ url="/data/", action="get", fields=[], description="Return data."
)
- }
+ },
)
factory = APIRequestFactory()
- request = factory.get('/')
+ request = factory.get("/")
renderer = DocumentationRenderer()
- html = renderer.render(document, accepted_media_type="text/html", renderer_context={"request": request})
- assert 'Data Endpoint API
' in html
+ html = renderer.render(
+ document,
+ accepted_media_type="text/html",
+ renderer_context={"request": request},
+ )
+ assert "Data Endpoint API
" in html
def test_shell_code_example_rendering(self):
- template = loader.get_template('rest_framework/docs/langs/shell.html')
+ template = loader.get_template("rest_framework/docs/langs/shell.html")
context = {
- 'document': coreapi.Document(url='https://api.example.org/'),
- 'link_key': 'testcases > list',
- 'link': coreapi.Link(url='/data/', action='get', fields=[]),
+ "document": coreapi.Document(url="https://api.example.org/"),
+ "link_key": "testcases > list",
+ "link": coreapi.Link(url="/data/", action="get", fields=[]),
}
html = template.render(context)
- assert 'testcases list' in html
+ assert "testcases list" in html
-@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
class TestSchemaJSRenderer(TestCase):
-
def test_schemajs_output(self):
"""
Test output of the SchemaJS renderer as per #5608. Django 2.0 on Py3 prints binary data as b'xyz' in templates,
and the base64 encoding used by SchemaJSRenderer outputs base64 as binary. Test fix.
"""
factory = APIRequestFactory()
- request = factory.get('/')
+ request = factory.get("/")
renderer = SchemaJSRenderer()
- output = renderer.render('data', renderer_context={"request": request})
+ output = renderer.render("data", renderer_context={"request": request})
assert "'ImRhdGEi'" in output
assert "'b'ImRhdGEi''" not in output
diff --git a/tests/test_request.py b/tests/test_request.py
index 83d295a12..bfcd0769a 100644
--- a/tests/test_request.py
+++ b/tests/test_request.py
@@ -25,23 +25,24 @@ from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView
+
factory = APIRequestFactory()
class TestInitializer(TestCase):
def test_request_type(self):
- request = Request(factory.get('/'))
+ request = Request(factory.get("/"))
message = (
- 'The `request` argument must be an instance of '
- '`django.http.HttpRequest`, not `rest_framework.request.Request`.'
+ "The `request` argument must be an instance of "
+ "`django.http.HttpRequest`, not `rest_framework.request.Request`."
)
with self.assertRaisesMessage(AssertionError, message):
Request(request)
class PlainTextParser(BaseParser):
- media_type = 'text/plain'
+ media_type = "text/plain"
def parse(self, stream, media_type=None, parser_context=None):
"""
@@ -58,22 +59,22 @@ class TestContentParsing(TestCase):
"""
Ensure request.data returns empty QueryDict for GET request.
"""
- request = Request(factory.get('/'))
+ request = Request(factory.get("/"))
assert request.data == {}
def test_standard_behaviour_determines_no_content_HEAD(self):
"""
Ensure request.data returns empty QueryDict for HEAD request.
"""
- request = Request(factory.head('/'))
+ request = Request(factory.head("/"))
assert request.data == {}
def test_request_DATA_with_form_content(self):
"""
Ensure request.data returns content for POST request with form content.
"""
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
+ data = {"qwerty": "uiop"}
+ request = Request(factory.post("/", data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items())
@@ -82,9 +83,9 @@ class TestContentParsing(TestCase):
Ensure request.data returns content for POST request with
non-form content.
"""
- content = six.b('qwerty')
- content_type = 'text/plain'
- request = Request(factory.post('/', content, content_type=content_type))
+ content = six.b("qwerty")
+ content_type = "text/plain"
+ request = Request(factory.post("/", content, content_type=content_type))
request.parsers = (PlainTextParser(),)
assert request.data == content
@@ -92,8 +93,8 @@ class TestContentParsing(TestCase):
"""
Ensure request.POST returns content for POST request with form content.
"""
- data = {'qwerty': 'uiop'}
- request = Request(factory.post('/', data))
+ data = {"qwerty": "uiop"}
+ request = Request(factory.post("/", data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST.items()) == list(data.items())
@@ -102,17 +103,17 @@ class TestContentParsing(TestCase):
Ensure request.POST returns no content for POST request with file content.
"""
upload = SimpleUploadedFile("file.txt", b"file_content")
- request = Request(factory.post('/', {'upload': upload}))
+ request = Request(factory.post("/", {"upload": upload}))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.POST) == []
- assert list(request.FILES) == ['upload']
+ assert list(request.FILES) == ["upload"]
def test_standard_behaviour_determines_form_content_PUT(self):
"""
Ensure request.data returns content for PUT request with form content.
"""
- data = {'qwerty': 'uiop'}
- request = Request(factory.put('/', data))
+ data = {"qwerty": "uiop"}
+ request = Request(factory.put("/", data))
request.parsers = (FormParser(), MultiPartParser())
assert list(request.data.items()) == list(data.items())
@@ -121,10 +122,10 @@ class TestContentParsing(TestCase):
Ensure request.data returns content for PUT request with
non-form content.
"""
- content = six.b('qwerty')
- content_type = 'text/plain'
- request = Request(factory.put('/', content, content_type=content_type))
- request.parsers = (PlainTextParser(), )
+ content = six.b("qwerty")
+ content_type = "text/plain"
+ request = Request(factory.put("/", content, content_type=content_type))
+ request.parsers = (PlainTextParser(),)
assert request.data == content
@@ -132,7 +133,7 @@ class MockView(APIView):
authentication_classes = (SessionAuthentication,)
def post(self, request):
- if request.POST.get('example') is not None:
+ if request.POST.get("example") is not None:
return Response(status=status.HTTP_200_OK)
return Response(status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@@ -154,20 +155,20 @@ class FileUploadView(APIView):
urlpatterns = [
- url(r'^$', MockView.as_view()),
- url(r'^echo/$', EchoView.as_view()),
- url(r'^upload/$', FileUploadView.as_view())
+ url(r"^$", MockView.as_view()),
+ url(r"^echo/$", EchoView.as_view()),
+ url(r"^upload/$", FileUploadView.as_view()),
]
@override_settings(
- ROOT_URLCONF='tests.test_request',
- FILE_UPLOAD_HANDLERS=['django.core.files.uploadhandler.TemporaryFileUploadHandler'])
+ ROOT_URLCONF="tests.test_request",
+ FILE_UPLOAD_HANDLERS=["django.core.files.uploadhandler.TemporaryFileUploadHandler"],
+)
class FileUploadTests(TestCase):
-
def test_fileuploads_closed_at_request_end(self):
with tempfile.NamedTemporaryFile() as f:
- response = self.client.post('/upload/', {'file': f})
+ response = self.client.post("/upload/", {"file": f})
# sanity check that file was processed
assert len(response.data) == 1
@@ -176,13 +177,13 @@ class FileUploadTests(TestCase):
assert not os.path.exists(file)
-@override_settings(ROOT_URLCONF='tests.test_request')
+@override_settings(ROOT_URLCONF="tests.test_request")
class TestContentParsingWithAuthentication(TestCase):
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
- self.username = 'john'
- self.email = 'lennon@thebeatles.com'
- self.password = 'password'
+ self.username = "john"
+ self.email = "lennon@thebeatles.com"
+ self.password = "password"
self.user = User.objects.create_user(self.username, self.email, self.password)
def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
@@ -190,27 +191,26 @@ class TestContentParsingWithAuthentication(TestCase):
Ensures request.POST exists after SessionAuthentication when user
doesn't log in.
"""
- content = {'example': 'example'}
+ content = {"example": "example"}
- response = self.client.post('/', content)
+ response = self.client.post("/", content)
assert status.HTTP_200_OK == response.status_code
- response = self.csrf_client.post('/', content)
+ response = self.csrf_client.post("/", content)
assert status.HTTP_200_OK == response.status_code
class TestUserSetter(TestCase):
-
def setUp(self):
# Pass request object through session middleware so session is
# available to login and logout functions
- self.wrapped_request = factory.get('/')
+ self.wrapped_request = factory.get("/")
self.request = Request(self.wrapped_request)
SessionMiddleware().process_request(self.wrapped_request)
AuthenticationMiddleware().process_request(self.wrapped_request)
- User.objects.create_user('ringo', 'starr@thebeatles.com', 'yellow')
- self.user = authenticate(username='ringo', password='yellow')
+ User.objects.create_user("ringo", "starr@thebeatles.com", "yellow")
+ self.user = authenticate(username="ringo", password="yellow")
def test_user_can_be_set(self):
self.request.user = self.user
@@ -235,11 +235,14 @@ class TestUserSetter(TestCase):
This proves that when an AttributeError is raised inside of the request.user
property, that we can handle this and report the true, underlying error.
"""
+
class AuthRaisesAttributeError(object):
def authenticate(self, request):
self.MISSPELLED_NAME_THAT_DOESNT_EXIST
- request = Request(self.wrapped_request, authenticators=(AuthRaisesAttributeError(),))
+ request = Request(
+ self.wrapped_request, authenticators=(AuthRaisesAttributeError(),)
+ )
# The middleware processes the underlying Django request, sets anonymous user
assert self.wrapped_request.user.is_anonymous
@@ -254,7 +257,7 @@ class TestUserSetter(TestCase):
return
with pytest.raises(WrappedAttributeError, match=expected):
- hasattr(request, 'user')
+ hasattr(request, "user")
with pytest.raises(WrappedAttributeError, match=expected):
login(request, self.user)
@@ -262,25 +265,24 @@ class TestUserSetter(TestCase):
class TestAuthSetter(TestCase):
def test_auth_can_be_set(self):
- request = Request(factory.get('/'))
- request.auth = 'DUMMY'
- assert request.auth == 'DUMMY'
+ request = Request(factory.get("/"))
+ request.auth = "DUMMY"
+ assert request.auth == "DUMMY"
class TestSecure(TestCase):
-
def test_default_secure_false(self):
- request = Request(factory.get('/', secure=False))
- assert request.scheme == 'http'
+ request = Request(factory.get("/", secure=False))
+ assert request.scheme == "http"
def test_default_secure_true(self):
- request = Request(factory.get('/', secure=True))
- assert request.scheme == 'https'
+ request = Request(factory.get("/", secure=True))
+ assert request.scheme == "https"
class TestHttpRequest(TestCase):
def test_attribute_access_proxy(self):
- http_request = factory.get('/')
+ http_request = factory.get("/")
request = Request(http_request)
inner_sentinel = object()
@@ -293,31 +295,31 @@ class TestHttpRequest(TestCase):
def test_exception_proxy(self):
# ensure the exception message is not for the underlying WSGIRequest
- http_request = factory.get('/')
+ http_request = factory.get("/")
request = Request(http_request)
message = "'Request' object has no attribute 'inner_property'"
with self.assertRaisesMessage(AttributeError, message):
request.inner_property
- @override_settings(ROOT_URLCONF='tests.test_request')
+ @override_settings(ROOT_URLCONF="tests.test_request")
def test_duplicate_request_stream_parsing_exception(self):
"""
Check assumption that duplicate stream parsing will result in a
`RawPostDataException` being raised.
"""
- response = APIClient().post('/echo/', data={'a': 'b'}, format='json')
- request = response.renderer_context['request']
+ response = APIClient().post("/echo/", data={"a": "b"}, format="json")
+ request = response.renderer_context["request"]
# ensure that request stream was consumed by json parser
- assert request.content_type.startswith('application/json')
- assert response.data == {'a': 'b'}
+ assert request.content_type.startswith("application/json")
+ assert response.data == {"a": "b"}
# pass same HttpRequest to view, stream already consumed
with pytest.raises(RawPostDataException):
EchoView.as_view()(request._request)
- @override_settings(ROOT_URLCONF='tests.test_request')
+ @override_settings(ROOT_URLCONF="tests.test_request")
def test_duplicate_request_form_data_access(self):
"""
Form data is copied to the underlying django request for middleware
@@ -325,17 +327,17 @@ class TestHttpRequest(TestCase):
data is 'safe' in so far as accessing `request.POST` does not trigger
the duplicate stream parse exception.
"""
- response = APIClient().post('/echo/', data={'a': 'b'})
- request = response.renderer_context['request']
+ response = APIClient().post("/echo/", data={"a": "b"})
+ request = response.renderer_context["request"]
# ensure that request stream was consumed by form parser
- assert request.content_type.startswith('multipart/form-data')
- assert response.data == {'a': ['b']}
+ assert request.content_type.startswith("multipart/form-data")
+ assert response.data == {"a": ["b"]}
# pass same HttpRequest to view, form data set on underlying request
response = EchoView.as_view()(request._request)
- request = response.renderer_context['request']
+ request = response.renderer_context["request"]
# ensure that request stream was consumed by form parser
- assert request.content_type.startswith('multipart/form-data')
- assert response.data == {'a': ['b']}
+ assert request.content_type.startswith("multipart/form-data")
+ assert response.data == {"a": ["b"]}
diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py
index 161429f73..e0c376097 100644
--- a/tests/test_requests_client.py
+++ b/tests/test_requests_client.py
@@ -18,55 +18,48 @@ from rest_framework.views import APIView
class Root(APIView):
def get(self, request):
- return Response({
- 'method': request.method,
- 'query_params': request.query_params,
- })
+ return Response(
+ {"method": request.method, "query_params": request.query_params}
+ )
def post(self, request):
files = {
- key: (value.name, value.read())
- for key, value in request.FILES.items()
+ key: (value.name, value.read()) for key, value in request.FILES.items()
}
post = request.POST
json = None
- if request.META.get('CONTENT_TYPE') == 'application/json':
+ if request.META.get("CONTENT_TYPE") == "application/json":
json = request.data
- return Response({
- 'method': request.method,
- 'query_params': request.query_params,
- 'POST': post,
- 'FILES': files,
- 'JSON': json
- })
+ return Response(
+ {
+ "method": request.method,
+ "query_params": request.query_params,
+ "POST": post,
+ "FILES": files,
+ "JSON": json,
+ }
+ )
class HeadersView(APIView):
def get(self, request):
headers = {
- key[5:].replace('_', '-'): value
+ key[5:].replace("_", "-"): value
for key, value in request.META.items()
- if key.startswith('HTTP_')
+ if key.startswith("HTTP_")
}
- return Response({
- 'method': request.method,
- 'headers': headers
- })
+ return Response({"method": request.method, "headers": headers})
class SessionView(APIView):
def get(self, request):
- return Response({
- key: value for key, value in request.session.items()
- })
+ return Response({key: value for key, value in request.session.items()})
def post(self, request):
for key, value in request.data.items():
request.session[key] = value
- return Response({
- key: value for key, value in request.session.items()
- })
+ return Response({key: value for key, value in request.session.items()})
class AuthView(APIView):
@@ -76,181 +69,167 @@ class AuthView(APIView):
username = request.user.username
else:
username = None
- return Response({
- 'username': username
- })
+ return Response({"username": username})
@method_decorator(csrf_protect)
def post(self, request):
- username = request.data['username']
- password = request.data['password']
+ username = request.data["username"]
+ password = request.data["password"]
user = authenticate(username=username, password=password)
if user is None:
- return Response({'error': 'incorrect credentials'})
+ return Response({"error": "incorrect credentials"})
login(request, user)
- return redirect('/auth/')
+ return redirect("/auth/")
urlpatterns = [
- url(r'^$', Root.as_view(), name='root'),
- url(r'^headers/$', HeadersView.as_view(), name='headers'),
- url(r'^session/$', SessionView.as_view(), name='session'),
- url(r'^auth/$', AuthView.as_view(), name='auth'),
+ url(r"^$", Root.as_view(), name="root"),
+ url(r"^headers/$", HeadersView.as_view(), name="headers"),
+ url(r"^session/$", SessionView.as_view(), name="session"),
+ url(r"^auth/$", AuthView.as_view(), name="auth"),
]
-@unittest.skipUnless(requests, 'requests not installed')
-@override_settings(ROOT_URLCONF='tests.test_requests_client')
+@unittest.skipUnless(requests, "requests not installed")
+@override_settings(ROOT_URLCONF="tests.test_requests_client")
class RequestsClientTests(APITestCase):
def test_get_request(self):
client = RequestsClient()
- response = client.get('http://testserver/')
+ response = client.get("http://testserver/")
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {
- 'method': 'GET',
- 'query_params': {}
- }
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"method": "GET", "query_params": {}}
assert response.json() == expected
def test_get_request_query_params_in_url(self):
client = RequestsClient()
- response = client.get('http://testserver/?key=value')
+ response = client.get("http://testserver/?key=value")
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {
- 'method': 'GET',
- 'query_params': {'key': 'value'}
- }
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"method": "GET", "query_params": {"key": "value"}}
assert response.json() == expected
def test_get_request_query_params_by_kwarg(self):
client = RequestsClient()
- response = client.get('http://testserver/', params={'key': 'value'})
+ response = client.get("http://testserver/", params={"key": "value"})
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {
- 'method': 'GET',
- 'query_params': {'key': 'value'}
- }
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"method": "GET", "query_params": {"key": "value"}}
assert response.json() == expected
def test_get_with_headers(self):
client = RequestsClient()
- response = client.get('http://testserver/headers/', headers={'User-Agent': 'example'})
+ response = client.get(
+ "http://testserver/headers/", headers={"User-Agent": "example"}
+ )
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- headers = response.json()['headers']
- assert headers['USER-AGENT'] == 'example'
+ assert response.headers["Content-Type"] == "application/json"
+ headers = response.json()["headers"]
+ assert headers["USER-AGENT"] == "example"
def test_get_with_session_headers(self):
client = RequestsClient()
- client.headers.update({'User-Agent': 'example'})
- response = client.get('http://testserver/headers/')
+ client.headers.update({"User-Agent": "example"})
+ response = client.get("http://testserver/headers/")
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- headers = response.json()['headers']
- assert headers['USER-AGENT'] == 'example'
+ assert response.headers["Content-Type"] == "application/json"
+ headers = response.json()["headers"]
+ assert headers["USER-AGENT"] == "example"
def test_post_form_request(self):
client = RequestsClient()
- response = client.post('http://testserver/', data={'key': 'value'})
+ response = client.post("http://testserver/", data={"key": "value"})
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
+ assert response.headers["Content-Type"] == "application/json"
expected = {
- 'method': 'POST',
- 'query_params': {},
- 'POST': {'key': 'value'},
- 'FILES': {},
- 'JSON': None
+ "method": "POST",
+ "query_params": {},
+ "POST": {"key": "value"},
+ "FILES": {},
+ "JSON": None,
}
assert response.json() == expected
def test_post_json_request(self):
client = RequestsClient()
- response = client.post('http://testserver/', json={'key': 'value'})
+ response = client.post("http://testserver/", json={"key": "value"})
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
+ assert response.headers["Content-Type"] == "application/json"
expected = {
- 'method': 'POST',
- 'query_params': {},
- 'POST': {},
- 'FILES': {},
- 'JSON': {'key': 'value'}
+ "method": "POST",
+ "query_params": {},
+ "POST": {},
+ "FILES": {},
+ "JSON": {"key": "value"},
}
assert response.json() == expected
def test_post_multipart_request(self):
client = RequestsClient()
- files = {
- 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n')
- }
- response = client.post('http://testserver/', files=files)
+ files = {"file": ("report.csv", "some,data,to,send\nanother,row,to,send\n")}
+ response = client.post("http://testserver/", files=files)
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
+ assert response.headers["Content-Type"] == "application/json"
expected = {
- 'method': 'POST',
- 'query_params': {},
- 'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']},
- 'POST': {},
- 'JSON': None
+ "method": "POST",
+ "query_params": {},
+ "FILES": {
+ "file": ["report.csv", "some,data,to,send\nanother,row,to,send\n"]
+ },
+ "POST": {},
+ "JSON": None,
}
assert response.json() == expected
def test_session(self):
client = RequestsClient()
- response = client.get('http://testserver/session/')
+ response = client.get("http://testserver/session/")
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
+ assert response.headers["Content-Type"] == "application/json"
expected = {}
assert response.json() == expected
- response = client.post('http://testserver/session/', json={'example': 'abc'})
+ response = client.post("http://testserver/session/", json={"example": "abc"})
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {'example': 'abc'}
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"example": "abc"}
assert response.json() == expected
- response = client.get('http://testserver/session/')
+ response = client.get("http://testserver/session/")
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {'example': 'abc'}
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"example": "abc"}
assert response.json() == expected
def test_auth(self):
# Confirm session is not authenticated
client = RequestsClient()
- response = client.get('http://testserver/auth/')
+ response = client.get("http://testserver/auth/")
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {
- 'username': None
- }
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"username": None}
assert response.json() == expected
- assert 'csrftoken' in response.cookies
- csrftoken = response.cookies['csrftoken']
+ assert "csrftoken" in response.cookies
+ csrftoken = response.cookies["csrftoken"]
- user = User.objects.create(username='tom')
- user.set_password('password')
+ user = User.objects.create(username="tom")
+ user.set_password("password")
user.save()
# Perform a login
- response = client.post('http://testserver/auth/', json={
- 'username': 'tom',
- 'password': 'password'
- }, headers={'X-CSRFToken': csrftoken})
+ response = client.post(
+ "http://testserver/auth/",
+ json={"username": "tom", "password": "password"},
+ headers={"X-CSRFToken": csrftoken},
+ )
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {
- 'username': 'tom'
- }
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"username": "tom"}
assert response.json() == expected
# Confirm session is authenticated
- response = client.get('http://testserver/auth/')
+ response = client.get("http://testserver/auth/")
assert response.status_code == 200
- assert response.headers['Content-Type'] == 'application/json'
- expected = {
- 'username': 'tom'
- }
+ assert response.headers["Content-Type"] == "application/json"
+ expected = {"username": "tom"}
assert response.json() == expected
diff --git a/tests/test_response.py b/tests/test_response.py
index e92bf54c1..9c37d91ce 100644
--- a/tests/test_response.py
+++ b/tests/test_response.py
@@ -6,9 +6,7 @@ from django.utils import six
from rest_framework import generics, routers, serializers, status, viewsets
from rest_framework.parsers import JSONParser
-from rest_framework.renderers import (
- BaseRenderer, BrowsableAPIRenderer, JSONRenderer
-)
+from rest_framework.renderers import BaseRenderer, BrowsableAPIRenderer, JSONRenderer
from rest_framework.response import Response
from rest_framework.views import APIView
from tests.models import BasicModel
@@ -18,35 +16,35 @@ from tests.models import BasicModel
class BasicModelSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
- fields = '__all__'
+ fields = "__all__"
class MockPickleRenderer(BaseRenderer):
- media_type = 'application/pickle'
+ media_type = "application/pickle"
class MockJsonRenderer(BaseRenderer):
- media_type = 'application/json'
+ media_type = "application/json"
class MockTextMediaRenderer(BaseRenderer):
- media_type = 'text/html'
+ media_type = "text/html"
DUMMYSTATUS = status.HTTP_200_OK
-DUMMYCONTENT = 'dummycontent'
+DUMMYCONTENT = "dummycontent"
def RENDERER_A_SERIALIZER(x):
- return ('Renderer A: %s' % x).encode('ascii')
+ return ("Renderer A: %s" % x).encode("ascii")
def RENDERER_B_SERIALIZER(x):
- return ('Renderer B: %s' % x).encode('ascii')
+ return ("Renderer B: %s" % x).encode("ascii")
class RendererA(BaseRenderer):
- media_type = 'mock/renderera'
+ media_type = "mock/renderera"
format = "formata"
def render(self, data, media_type=None, renderer_context=None):
@@ -54,7 +52,7 @@ class RendererA(BaseRenderer):
class RendererB(BaseRenderer):
- media_type = 'mock/rendererb'
+ media_type = "mock/rendererb"
format = "formatb"
def render(self, data, media_type=None, renderer_context=None):
@@ -62,8 +60,8 @@ class RendererB(BaseRenderer):
class RendererC(RendererB):
- media_type = 'mock/rendererc'
- format = 'formatc'
+ media_type = "mock/rendererc"
+ format = "formatc"
charset = "rendererc"
@@ -78,7 +76,7 @@ class MockViewSettingContentType(APIView):
renderer_classes = (RendererA, RendererB, RendererC)
def get(self, request, **kwargs):
- return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
+ return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type="setbyview")
class JSONView(APIView):
@@ -90,17 +88,17 @@ class JSONView(APIView):
class HTMLView(APIView):
- renderer_classes = (BrowsableAPIRenderer, )
+ renderer_classes = (BrowsableAPIRenderer,)
def get(self, request, **kwargs):
- return Response('text')
+ return Response("text")
class HTMLView1(APIView):
renderer_classes = (BrowsableAPIRenderer, JSONRenderer)
def get(self, request, **kwargs):
- return Response('text')
+ return Response("text")
class HTMLNewModelViewSet(viewsets.ModelViewSet):
@@ -116,152 +114,169 @@ class HTMLNewModelView(generics.ListCreateAPIView):
new_model_viewset_router = routers.DefaultRouter()
-new_model_viewset_router.register(r'', HTMLNewModelViewSet)
+new_model_viewset_router.register(r"", HTMLNewModelViewSet)
urlpatterns = [
- url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^.*\.(?P.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
- url(r'^html$', HTMLView.as_view()),
- url(r'^json$', JSONView.as_view()),
- url(r'^html1$', HTMLView1.as_view()),
- url(r'^html_new_model$', HTMLNewModelView.as_view()),
- url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
- url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
+ url(
+ r"^setbyview$",
+ MockViewSettingContentType.as_view(
+ renderer_classes=[RendererA, RendererB, RendererC]
+ ),
+ ),
+ url(
+ r"^.*\.(?P.+)$",
+ MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC]),
+ ),
+ url(r"^$", MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
+ url(r"^html$", HTMLView.as_view()),
+ url(r"^json$", JSONView.as_view()),
+ url(r"^html1$", HTMLView1.as_view()),
+ url(r"^html_new_model$", HTMLNewModelView.as_view()),
+ url(r"^html_new_model_viewset", include(new_model_viewset_router.urls)),
+ url(r"^restframework", include("rest_framework.urls", namespace="rest_framework")),
]
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ...
-@override_settings(ROOT_URLCONF='tests.test_response')
+@override_settings(ROOT_URLCONF="tests.test_response")
class RendererIntegrationTests(TestCase):
"""
End-to-end testing of renderers using an ResponseMixin on a generic view.
"""
+
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
- resp = self.client.get('/')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ resp = self.client.get("/")
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_head_method_serializes_no_content(self):
"""No response must be included in HEAD requests."""
- resp = self.client.head('/')
+ resp = self.client.head("/")
self.assertEqual(resp.status_code, DUMMYSTATUS)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
- self.assertEqual(resp.content, six.b(''))
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
+ self.assertEqual(resp.content, six.b(""))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
- resp = self.client.get('/', HTTP_ACCEPT='*/*')
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ resp = self.client.get("/", HTTP_ACCEPT="*/*")
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
- self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
+ resp = self.client.get("/", HTTP_ACCEPT=RendererA.media_type)
+ self.assertEqual(resp["Content-Type"], RendererA.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
- resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ resp = self.client.get("/", HTTP_ACCEPT=RendererB.media_type)
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_query(self):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ resp = self.client.get("/?format=%s" % RendererB.format)
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
- resp = self.client.get('/something.formatb')
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ resp = self.client.get("/something.formatb")
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
- resp = self.client.get('/?format=%s' % RendererB.format,
- HTTP_ACCEPT=RendererB.media_type)
- self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
+ resp = self.client.get(
+ "/?format=%s" % RendererB.format, HTTP_ACCEPT=RendererB.media_type
+ )
+ self.assertEqual(resp["Content-Type"], RendererB.media_type + "; charset=utf-8")
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
-@override_settings(ROOT_URLCONF='tests.test_response')
+@override_settings(ROOT_URLCONF="tests.test_response")
class UnsupportedMediaTypeTests(TestCase):
def test_should_allow_posting_json(self):
- response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
+ response = self.client.post(
+ "/json", data='{"test": 123}', content_type="application/json"
+ )
self.assertEqual(response.status_code, 200)
def test_should_not_allow_posting_xml(self):
- response = self.client.post('/json', data='123', content_type='application/xml')
+ response = self.client.post(
+ "/json", data="123", content_type="application/xml"
+ )
self.assertEqual(response.status_code, 415)
def test_should_not_allow_posting_a_form(self):
- response = self.client.post('/json', data={'test': 123})
+ response = self.client.post("/json", data={"test": 123})
self.assertEqual(response.status_code, 415)
-@override_settings(ROOT_URLCONF='tests.test_response')
+@override_settings(ROOT_URLCONF="tests.test_response")
class Issue122Tests(TestCase):
"""
Tests that covers #122.
"""
+
def test_only_html_renderer(self):
"""
Test if no infinite recursion occurs.
"""
- self.client.get('/html')
+ self.client.get("/html")
def test_html_renderer_is_first(self):
"""
Test if no infinite recursion occurs.
"""
- self.client.get('/html1')
+ self.client.get("/html1")
-@override_settings(ROOT_URLCONF='tests.test_response')
+@override_settings(ROOT_URLCONF="tests.test_response")
class Issue467Tests(TestCase):
"""
Tests for #467
"""
+
def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ resp = self.client.get("/html_new_model")
+ self.assertEqual(resp["Content-Type"], "text/html; charset=utf-8")
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')
-@override_settings(ROOT_URLCONF='tests.test_response')
+@override_settings(ROOT_URLCONF="tests.test_response")
class Issue807Tests(TestCase):
"""
Covers #807
"""
+
def test_does_not_append_charset_by_default(self):
"""
Renderers don't include a charset unless set explicitly.
"""
headers = {"HTTP_ACCEPT": RendererA.media_type}
- resp = self.client.get('/', **headers)
- expected = "{0}; charset={1}".format(RendererA.media_type, 'utf-8')
- self.assertEqual(expected, resp['Content-Type'])
+ resp = self.client.get("/", **headers)
+ expected = "{0}; charset={1}".format(RendererA.media_type, "utf-8")
+ self.assertEqual(expected, resp["Content-Type"])
def test_if_there_is_charset_specified_on_renderer_it_gets_appended(self):
"""
@@ -269,20 +284,20 @@ class Issue807Tests(TestCase):
to Response's Content-Type
"""
headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/', **headers)
+ resp = self.client.get("/", **headers)
expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
- self.assertEqual(expected, resp['Content-Type'])
+ self.assertEqual(expected, resp["Content-Type"])
def test_content_type_set_explicitly_on_response(self):
"""
The content type may be set explicitly on the response.
"""
headers = {"HTTP_ACCEPT": RendererC.media_type}
- resp = self.client.get('/setbyview', **headers)
- self.assertEqual('setbyview', resp['Content-Type'])
+ resp = self.client.get("/setbyview", **headers)
+ self.assertEqual("setbyview", resp["Content-Type"])
def test_form_has_label_and_help_text(self):
- resp = self.client.get('/html_new_model')
- self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
+ resp = self.client.get("/html_new_model")
+ self.assertEqual(resp["Content-Type"], "text/html; charset=utf-8")
# self.assertContains(resp, 'Text comes here')
# self.assertContains(resp, 'Text description.')
diff --git a/tests/test_reverse.py b/tests/test_reverse.py
index 145b1a54f..ad78b32a5 100644
--- a/tests/test_reverse.py
+++ b/tests/test_reverse.py
@@ -7,6 +7,7 @@ from django.urls import NoReverseMatch
from rest_framework.reverse import reverse
from rest_framework.test import APIRequestFactory
+
factory = APIRequestFactory()
@@ -14,13 +15,10 @@ def null_view(request):
pass
-urlpatterns = [
- url(r'^view$', null_view, name='view'),
-]
+urlpatterns = [url(r"^view$", null_view, name="view")]
class MockVersioningScheme(object):
-
def __init__(self, raise_error=False):
self.raise_error = raise_error
@@ -28,29 +26,30 @@ class MockVersioningScheme(object):
if self.raise_error:
raise NoReverseMatch()
- return 'http://scheme-reversed/view'
+ return "http://scheme-reversed/view"
-@override_settings(ROOT_URLCONF='tests.test_reverse')
+@override_settings(ROOT_URLCONF="tests.test_reverse")
class ReverseTests(TestCase):
"""
Tests for fully qualified URLs when using `reverse`.
"""
+
def test_reversed_urls_are_fully_qualified(self):
- request = factory.get('/view')
- url = reverse('view', request=request)
- assert url == 'http://testserver/view'
+ request = factory.get("/view")
+ url = reverse("view", request=request)
+ assert url == "http://testserver/view"
def test_reverse_with_versioning_scheme(self):
- request = factory.get('/view')
+ request = factory.get("/view")
request.versioning_scheme = MockVersioningScheme()
- url = reverse('view', request=request)
- assert url == 'http://scheme-reversed/view'
+ url = reverse("view", request=request)
+ assert url == "http://scheme-reversed/view"
def test_reverse_with_versioning_scheme_fallback_to_default_on_error(self):
- request = factory.get('/view')
+ request = factory.get("/view")
request.versioning_scheme = MockVersioningScheme(raise_error=True)
- url = reverse('view', request=request)
- assert url == 'http://testserver/view'
+ url = reverse("view", request=request)
+ assert url == "http://testserver/view"
diff --git a/tests/test_routers.py b/tests/test_routers.py
index cca2ea712..eb7c0fff1 100644
--- a/tests/test_routers.py
+++ b/tests/test_routers.py
@@ -10,9 +10,7 @@ from django.db import models
from django.test import TestCase, override_settings
from django.urls import resolve, reverse
-from rest_framework import (
- RemovedInDRF311Warning, permissions, serializers, viewsets
-)
+from rest_framework import RemovedInDRF311Warning, permissions, serializers, viewsets
from rest_framework.compat import get_regex_pattern
from rest_framework.decorators import action
from rest_framework.response import Response
@@ -20,6 +18,7 @@ from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.test import APIRequestFactory, URLPatternsTestCase
from rest_framework.utils import json
+
factory = APIRequestFactory()
@@ -29,24 +28,26 @@ class RouterTestModel(models.Model):
class NoteSerializer(serializers.HyperlinkedModelSerializer):
- url = serializers.HyperlinkedIdentityField(view_name='routertestmodel-detail', lookup_field='uuid')
+ url = serializers.HyperlinkedIdentityField(
+ view_name="routertestmodel-detail", lookup_field="uuid"
+ )
class Meta:
model = RouterTestModel
- fields = ('url', 'uuid', 'text')
+ fields = ("url", "uuid", "text")
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
serializer_class = NoteSerializer
- lookup_field = 'uuid'
+ lookup_field = "uuid"
class KWargedNoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
serializer_class = NoteSerializer
- lookup_field = 'text__contains'
- lookup_url_kwarg = 'text'
+ lookup_field = "text__contains"
+ lookup_url_kwarg = "text"
class MockViewSet(viewsets.ModelViewSet):
@@ -57,75 +58,76 @@ class MockViewSet(viewsets.ModelViewSet):
class EmptyPrefixSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RouterTestModel
- fields = ('uuid', 'text')
+ fields = ("uuid", "text")
class EmptyPrefixViewSet(viewsets.ModelViewSet):
- queryset = [RouterTestModel(id=1, uuid='111', text='First'), RouterTestModel(id=2, uuid='222', text='Second')]
+ queryset = [
+ RouterTestModel(id=1, uuid="111", text="First"),
+ RouterTestModel(id=2, uuid="222", text="Second"),
+ ]
serializer_class = EmptyPrefixSerializer
def get_object(self, *args, **kwargs):
- index = int(self.kwargs['pk']) - 1
+ index = int(self.kwargs["pk"]) - 1
return self.queryset[index]
class RegexUrlPathViewSet(viewsets.ViewSet):
- @action(detail=False, url_path='list/(?P[0-9]{4})')
+ @action(detail=False, url_path="list/(?P[0-9]{4})")
def regex_url_path_list(self, request, *args, **kwargs):
- kwarg = self.kwargs.get('kwarg', '')
- return Response({'kwarg': kwarg})
+ kwarg = self.kwargs.get("kwarg", "")
+ return Response({"kwarg": kwarg})
- @action(detail=True, url_path='detail/(?P[0-9]{4})')
+ @action(detail=True, url_path="detail/(?P[0-9]{4})")
def regex_url_path_detail(self, request, *args, **kwargs):
- pk = self.kwargs.get('pk', '')
- kwarg = self.kwargs.get('kwarg', '')
- return Response({'pk': pk, 'kwarg': kwarg})
+ pk = self.kwargs.get("pk", "")
+ kwarg = self.kwargs.get("kwarg", "")
+ return Response({"pk": pk, "kwarg": kwarg})
notes_router = SimpleRouter()
-notes_router.register(r'notes', NoteViewSet)
+notes_router.register(r"notes", NoteViewSet)
kwarged_notes_router = SimpleRouter()
-kwarged_notes_router.register(r'notes', KWargedNoteViewSet)
+kwarged_notes_router.register(r"notes", KWargedNoteViewSet)
namespaced_router = DefaultRouter()
-namespaced_router.register(r'example', MockViewSet, basename='example')
+namespaced_router.register(r"example", MockViewSet, basename="example")
empty_prefix_router = SimpleRouter()
-empty_prefix_router.register(r'', EmptyPrefixViewSet, basename='empty_prefix')
+empty_prefix_router.register(r"", EmptyPrefixViewSet, basename="empty_prefix")
regex_url_path_router = SimpleRouter()
-regex_url_path_router.register(r'', RegexUrlPathViewSet, basename='regex')
+regex_url_path_router.register(r"", RegexUrlPathViewSet, basename="regex")
class BasicViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
- return Response({'method': 'list'})
+ return Response({"method": "list"})
- @action(methods=['post'], detail=True)
+ @action(methods=["post"], detail=True)
def action1(self, request, *args, **kwargs):
- return Response({'method': 'action1'})
+ return Response({"method": "action1"})
- @action(methods=['post', 'delete'], detail=True)
+ @action(methods=["post", "delete"], detail=True)
def action2(self, request, *args, **kwargs):
- return Response({'method': 'action2'})
+ return Response({"method": "action2"})
- @action(methods=['post'], detail=True)
+ @action(methods=["post"], detail=True)
def action3(self, request, pk, *args, **kwargs):
- return Response({'post': pk})
+ return Response({"post": pk})
@action3.mapping.delete
def action3_delete(self, request, pk, *args, **kwargs):
- return Response({'delete': pk})
+ return Response({"delete": pk})
class TestSimpleRouter(URLPatternsTestCase, TestCase):
router = SimpleRouter()
- router.register('basics', BasicViewSet, basename='basic')
+ router.register("basics", BasicViewSet, basename="basic")
- urlpatterns = [
- url(r'^api/', include(router.urls)),
- ]
+ urlpatterns = [url(r"^api/", include(router.urls))]
def setUp(self):
self.router = SimpleRouter()
@@ -134,51 +136,46 @@ class TestSimpleRouter(URLPatternsTestCase, TestCase):
# Get action routes (first two are list/detail)
routes = self.router.get_routes(BasicViewSet)[2:]
- assert routes[0].url == '^{prefix}/{lookup}/action1{trailing_slash}$'
- assert routes[0].mapping == {
- 'post': 'action1',
- }
+ assert routes[0].url == "^{prefix}/{lookup}/action1{trailing_slash}$"
+ assert routes[0].mapping == {"post": "action1"}
- assert routes[1].url == '^{prefix}/{lookup}/action2{trailing_slash}$'
- assert routes[1].mapping == {
- 'post': 'action2',
- 'delete': 'action2',
- }
+ assert routes[1].url == "^{prefix}/{lookup}/action2{trailing_slash}$"
+ assert routes[1].mapping == {"post": "action2", "delete": "action2"}
- assert routes[2].url == '^{prefix}/{lookup}/action3{trailing_slash}$'
- assert routes[2].mapping == {
- 'post': 'action3',
- 'delete': 'action3_delete',
- }
+ assert routes[2].url == "^{prefix}/{lookup}/action3{trailing_slash}$"
+ assert routes[2].mapping == {"post": "action3", "delete": "action3_delete"}
def test_multiple_action_handlers(self):
# Standard action
- response = self.client.post(reverse('basic-action3', args=[1]))
- assert response.data == {'post': '1'}
+ response = self.client.post(reverse("basic-action3", args=[1]))
+ assert response.data == {"post": "1"}
# Additional handler registered with MethodMapper
- response = self.client.delete(reverse('basic-action3', args=[1]))
- assert response.data == {'delete': '1'}
+ response = self.client.delete(reverse("basic-action3", args=[1]))
+ assert response.data == {"delete": "1"}
def test_register_after_accessing_urls(self):
- self.router.register(r'notes', NoteViewSet)
+ self.router.register(r"notes", NoteViewSet)
assert len(self.router.urls) == 2 # list and detail
- self.router.register(r'notes_bis', NoteViewSet)
+ self.router.register(r"notes_bis", NoteViewSet)
assert len(self.router.urls) == 4
class TestRootView(URLPatternsTestCase, TestCase):
urlpatterns = [
- url(r'^non-namespaced/', include(namespaced_router.urls)),
- url(r'^namespaced/', include((namespaced_router.urls, 'namespaced'), namespace='namespaced')),
+ url(r"^non-namespaced/", include(namespaced_router.urls)),
+ url(
+ r"^namespaced/",
+ include((namespaced_router.urls, "namespaced"), namespace="namespaced"),
+ ),
]
def test_retrieve_namespaced_root(self):
- response = self.client.get('/namespaced/')
+ response = self.client.get("/namespaced/")
assert response.data == {"example": "http://testserver/namespaced/example/"}
def test_retrieve_non_namespaced_root(self):
- response = self.client.get('/non-namespaced/')
+ response = self.client.get("/non-namespaced/")
assert response.data == {"example": "http://testserver/non-namespaced/example/"}
@@ -186,34 +183,51 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase):
"""
Ensure that custom lookup fields are correctly routed.
"""
+
urlpatterns = [
- url(r'^example/', include(notes_router.urls)),
- url(r'^example2/', include(kwarged_notes_router.urls)),
+ url(r"^example/", include(notes_router.urls)),
+ url(r"^example2/", include(kwarged_notes_router.urls)),
]
def setUp(self):
- RouterTestModel.objects.create(uuid='123', text='foo bar')
- RouterTestModel.objects.create(uuid='a b', text='baz qux')
+ RouterTestModel.objects.create(uuid="123", text="foo bar")
+ RouterTestModel.objects.create(uuid="a b", text="baz qux")
def test_custom_lookup_field_route(self):
detail_route = notes_router.urls[-1]
detail_url_pattern = get_regex_pattern(detail_route)
- assert '' in detail_url_pattern
+ assert "" in detail_url_pattern
def test_retrieve_lookup_field_list_view(self):
- response = self.client.get('/example/notes/')
+ response = self.client.get("/example/notes/")
assert response.data == [
- {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"},
- {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"},
+ {
+ "url": "http://testserver/example/notes/123/",
+ "uuid": "123",
+ "text": "foo bar",
+ },
+ {
+ "url": "http://testserver/example/notes/a%20b/",
+ "uuid": "a b",
+ "text": "baz qux",
+ },
]
def test_retrieve_lookup_field_detail_view(self):
- response = self.client.get('/example/notes/123/')
- assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
+ response = self.client.get("/example/notes/123/")
+ assert response.data == {
+ "url": "http://testserver/example/notes/123/",
+ "uuid": "123",
+ "text": "foo bar",
+ }
def test_retrieve_lookup_field_url_encoded_detail_view_(self):
- response = self.client.get('/example/notes/a%20b/')
- assert response.data == {"url": "http://testserver/example/notes/a%20b/", "uuid": "a b", "text": "baz qux"}
+ response = self.client.get("/example/notes/a%20b/")
+ assert response.data == {
+ "url": "http://testserver/example/notes/a%20b/",
+ "uuid": "a b",
+ "text": "baz qux",
+ }
class TestLookupValueRegex(TestCase):
@@ -221,49 +235,59 @@ class TestLookupValueRegex(TestCase):
Ensure the router honors lookup_value_regex when applied
to the viewset.
"""
+
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
queryset = RouterTestModel.objects.all()
- lookup_field = 'uuid'
- lookup_value_regex = '[0-9a-f]{32}'
+ lookup_field = "uuid"
+ lookup_value_regex = "[0-9a-f]{32}"
self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
+ self.router.register(r"notes", NoteViewSet)
self.urls = self.router.urls
def test_urls_limited_by_lookup_value_regex(self):
- expected = ['^notes/$', '^notes/(?P[0-9a-f]{32})/$']
+ expected = ["^notes/$", "^notes/(?P[0-9a-f]{32})/$"]
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
-@override_settings(ROOT_URLCONF='tests.test_routers')
+@override_settings(ROOT_URLCONF="tests.test_routers")
class TestLookupUrlKwargs(URLPatternsTestCase, TestCase):
"""
Ensure the router honors lookup_url_kwarg.
Setup a deep lookup_field, but map it to a simple URL kwarg.
"""
+
urlpatterns = [
- url(r'^example/', include(notes_router.urls)),
- url(r'^example2/', include(kwarged_notes_router.urls)),
+ url(r"^example/", include(notes_router.urls)),
+ url(r"^example2/", include(kwarged_notes_router.urls)),
]
def setUp(self):
- RouterTestModel.objects.create(uuid='123', text='foo bar')
+ RouterTestModel.objects.create(uuid="123", text="foo bar")
def test_custom_lookup_url_kwarg_route(self):
detail_route = kwarged_notes_router.urls[-1]
detail_url_pattern = get_regex_pattern(detail_route)
- assert '^notes/(?P' in detail_url_pattern
+ assert "^notes/(?P" in detail_url_pattern
def test_retrieve_lookup_url_kwarg_detail_view(self):
- response = self.client.get('/example2/notes/fo/')
- assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
+ response = self.client.get("/example2/notes/fo/")
+ assert response.data == {
+ "url": "http://testserver/example/notes/123/",
+ "uuid": "123",
+ "text": "foo bar",
+ }
def test_retrieve_lookup_url_encoded_kwarg_detail_view(self):
- response = self.client.get('/example2/notes/foo%20bar/')
- assert response.data == {"url": "http://testserver/example/notes/123/", "uuid": "123", "text": "foo bar"}
+ response = self.client.get("/example2/notes/foo%20bar/")
+ assert response.data == {
+ "url": "http://testserver/example/notes/123/",
+ "uuid": "123",
+ "text": "foo bar",
+ }
class TestTrailingSlashIncluded(TestCase):
@@ -272,11 +296,11 @@ class TestTrailingSlashIncluded(TestCase):
queryset = RouterTestModel.objects.all()
self.router = SimpleRouter()
- self.router.register(r'notes', NoteViewSet)
+ self.router.register(r"notes", NoteViewSet)
self.urls = self.router.urls
def test_urls_have_trailing_slash_by_default(self):
- expected = ['^notes/$', '^notes/(?P[^/.]+)/$']
+ expected = ["^notes/$", "^notes/(?P[^/.]+)/$"]
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
@@ -287,11 +311,11 @@ class TestTrailingSlashRemoved(TestCase):
queryset = RouterTestModel.objects.all()
self.router = SimpleRouter(trailing_slash=False)
- self.router.register(r'notes', NoteViewSet)
+ self.router.register(r"notes", NoteViewSet)
self.urls = self.router.urls
def test_urls_can_have_trailing_slash_removed(self):
- expected = ['^notes$', '^notes/(?P[^/.]+)$']
+ expected = ["^notes$", "^notes/(?P[^/.]+)$"]
for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx])
@@ -302,12 +326,12 @@ class TestNameableRoot(TestCase):
queryset = RouterTestModel.objects.all()
self.router = DefaultRouter()
- self.router.root_view_name = 'nameable-root'
- self.router.register(r'notes', NoteViewSet)
+ self.router.root_view_name = "nameable-root"
+ self.router.register(r"notes", NoteViewSet)
self.urls = self.router.urls
def test_router_has_custom_name(self):
- expected = 'nameable-root'
+ expected = "nameable-root"
assert expected == self.urls[-1].name
@@ -321,20 +345,20 @@ class TestActionKeywordArgs(TestCase):
class TestViewSet(viewsets.ModelViewSet):
permission_classes = []
- @action(methods=['post'], detail=True, permission_classes=[permissions.AllowAny])
+ @action(
+ methods=["post"], detail=True, permission_classes=[permissions.AllowAny]
+ )
def custom(self, request, *args, **kwargs):
- return Response({
- 'permission_classes': self.permission_classes
- })
+ return Response({"permission_classes": self.permission_classes})
self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, basename='test')
+ self.router.register(r"test", TestViewSet, basename="test")
self.view = self.router.urls[-1].callback
def test_action_kwargs(self):
- request = factory.post('/test/0/custom/')
+ request = factory.post("/test/0/custom/")
response = self.view(request)
- assert response.data == {'permission_classes': [permissions.AllowAny]}
+ assert response.data == {"permission_classes": [permissions.AllowAny]}
class TestActionAppliedToExistingRoute(TestCase):
@@ -345,15 +369,12 @@ class TestActionAppliedToExistingRoute(TestCase):
def test_exception_raised_when_action_applied_to_existing_route(self):
class TestViewSet(viewsets.ModelViewSet):
-
- @action(methods=['post'], detail=True)
+ @action(methods=["post"], detail=True)
def retrieve(self, request, *args, **kwargs):
- return Response({
- 'hello': 'world'
- })
+ return Response({"hello": "world"})
self.router = SimpleRouter()
- self.router.register(r'test', TestViewSet, basename='test')
+ self.router.register(r"test", TestViewSet, basename="test")
with pytest.raises(ImproperlyConfigured):
self.router.urls
@@ -361,31 +382,31 @@ class TestActionAppliedToExistingRoute(TestCase):
class DynamicListAndDetailViewSet(viewsets.ViewSet):
def list(self, request, *args, **kwargs):
- return Response({'method': 'list'})
+ return Response({"method": "list"})
- @action(methods=['post'], detail=False)
+ @action(methods=["post"], detail=False)
def list_route_post(self, request, *args, **kwargs):
- return Response({'method': 'action1'})
+ return Response({"method": "action1"})
- @action(methods=['post'], detail=True)
+ @action(methods=["post"], detail=True)
def detail_route_post(self, request, *args, **kwargs):
- return Response({'method': 'action2'})
+ return Response({"method": "action2"})
@action(detail=False)
def list_route_get(self, request, *args, **kwargs):
- return Response({'method': 'link1'})
+ return Response({"method": "link1"})
@action(detail=True)
def detail_route_get(self, request, *args, **kwargs):
- return Response({'method': 'link2'})
+ return Response({"method": "link2"})
@action(detail=False, url_path="list_custom-route")
def list_custom_route_get(self, request, *args, **kwargs):
- return Response({'method': 'link1'})
+ return Response({"method": "link1"})
@action(detail=True, url_path="detail_custom-route")
def detail_custom_route_get(self, request, *args, **kwargs):
- return Response({'method': 'link2'})
+ return Response({"method": "link2"})
class SubDynamicListAndDetailViewSet(DynamicListAndDetailViewSet):
@@ -398,31 +419,43 @@ class TestDynamicListAndDetailRouter(TestCase):
def _test_list_and_detail_route_decorators(self, viewset):
routes = self.router.get_routes(viewset)
- decorator_routes = [r for r in routes if not (r.name.endswith('-list') or r.name.endswith('-detail'))]
+ decorator_routes = [
+ r
+ for r in routes
+ if not (r.name.endswith("-list") or r.name.endswith("-detail"))
+ ]
- MethodNamesMap = namedtuple('MethodNamesMap', 'method_name url_path')
+ MethodNamesMap = namedtuple("MethodNamesMap", "method_name url_path")
# Make sure all these endpoints exist and none have been clobbered
- for i, endpoint in enumerate([MethodNamesMap('list_custom_route_get', 'list_custom-route'),
- MethodNamesMap('list_route_get', 'list_route_get'),
- MethodNamesMap('list_route_post', 'list_route_post'),
- MethodNamesMap('detail_custom_route_get', 'detail_custom-route'),
- MethodNamesMap('detail_route_get', 'detail_route_get'),
- MethodNamesMap('detail_route_post', 'detail_route_post')
- ]):
+ for i, endpoint in enumerate(
+ [
+ MethodNamesMap("list_custom_route_get", "list_custom-route"),
+ MethodNamesMap("list_route_get", "list_route_get"),
+ MethodNamesMap("list_route_post", "list_route_post"),
+ MethodNamesMap("detail_custom_route_get", "detail_custom-route"),
+ MethodNamesMap("detail_route_get", "detail_route_get"),
+ MethodNamesMap("detail_route_post", "detail_route_post"),
+ ]
+ ):
route = decorator_routes[i]
# check url listing
method_name = endpoint.method_name
url_path = endpoint.url_path
- if method_name.startswith('list_'):
- assert route.url == '^{{prefix}}/{0}{{trailing_slash}}$'.format(url_path)
+ if method_name.startswith("list_"):
+ assert route.url == "^{{prefix}}/{0}{{trailing_slash}}$".format(
+ url_path
+ )
else:
- assert route.url == '^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$'.format(url_path)
+ assert (
+ route.url
+ == "^{{prefix}}/{{lookup}}/{0}{{trailing_slash}}$".format(url_path)
+ )
# check method to function mapping
- if method_name.endswith('_post'):
- method_map = 'post'
+ if method_name.endswith("_post"):
+ method_map = "post"
else:
- method_map = 'get'
+ method_map = "get"
assert route.mapping[method_map] == method_name
def test_list_and_detail_route_decorators(self):
@@ -433,75 +466,77 @@ class TestDynamicListAndDetailRouter(TestCase):
class TestEmptyPrefix(URLPatternsTestCase, TestCase):
- urlpatterns = [
- url(r'^empty-prefix/', include(empty_prefix_router.urls)),
- ]
+ urlpatterns = [url(r"^empty-prefix/", include(empty_prefix_router.urls))]
def test_empty_prefix_list(self):
- response = self.client.get('/empty-prefix/')
+ response = self.client.get("/empty-prefix/")
assert response.status_code == 200
- assert json.loads(response.content.decode('utf-8')) == [{'uuid': '111', 'text': 'First'},
- {'uuid': '222', 'text': 'Second'}]
+ assert json.loads(response.content.decode("utf-8")) == [
+ {"uuid": "111", "text": "First"},
+ {"uuid": "222", "text": "Second"},
+ ]
def test_empty_prefix_detail(self):
- response = self.client.get('/empty-prefix/1/')
+ response = self.client.get("/empty-prefix/1/")
assert response.status_code == 200
- assert json.loads(response.content.decode('utf-8')) == {'uuid': '111', 'text': 'First'}
+ assert json.loads(response.content.decode("utf-8")) == {
+ "uuid": "111",
+ "text": "First",
+ }
class TestRegexUrlPath(URLPatternsTestCase, TestCase):
- urlpatterns = [
- url(r'^regex/', include(regex_url_path_router.urls)),
- ]
+ urlpatterns = [url(r"^regex/", include(regex_url_path_router.urls))]
def test_regex_url_path_list(self):
- kwarg = '1234'
- response = self.client.get('/regex/list/{}/'.format(kwarg))
+ kwarg = "1234"
+ response = self.client.get("/regex/list/{}/".format(kwarg))
assert response.status_code == 200
- assert json.loads(response.content.decode('utf-8')) == {'kwarg': kwarg}
+ assert json.loads(response.content.decode("utf-8")) == {"kwarg": kwarg}
def test_regex_url_path_detail(self):
- pk = '1'
- kwarg = '1234'
- response = self.client.get('/regex/{}/detail/{}/'.format(pk, kwarg))
+ pk = "1"
+ kwarg = "1234"
+ response = self.client.get("/regex/{}/detail/{}/".format(pk, kwarg))
assert response.status_code == 200
- assert json.loads(response.content.decode('utf-8')) == {'pk': pk, 'kwarg': kwarg}
+ assert json.loads(response.content.decode("utf-8")) == {
+ "pk": pk,
+ "kwarg": kwarg,
+ }
class TestViewInitkwargs(URLPatternsTestCase, TestCase):
- urlpatterns = [
- url(r'^example/', include(notes_router.urls)),
- ]
+ urlpatterns = [url(r"^example/", include(notes_router.urls))]
def test_suffix(self):
- match = resolve('/example/notes/')
+ match = resolve("/example/notes/")
initkwargs = match.func.initkwargs
- assert initkwargs['suffix'] == 'List'
+ assert initkwargs["suffix"] == "List"
def test_detail(self):
- match = resolve('/example/notes/')
+ match = resolve("/example/notes/")
initkwargs = match.func.initkwargs
- assert not initkwargs['detail']
+ assert not initkwargs["detail"]
def test_basename(self):
- match = resolve('/example/notes/')
+ match = resolve("/example/notes/")
initkwargs = match.func.initkwargs
- assert initkwargs['basename'] == 'routertestmodel'
+ assert initkwargs["basename"] == "routertestmodel"
class TestBaseNameRename(TestCase):
-
def test_base_name_and_basename_assertion(self):
router = SimpleRouter()
msg = "Do not provide both the `basename` and `base_name` arguments."
- with warnings.catch_warnings(record=True) as w, \
- self.assertRaisesMessage(AssertionError, msg):
- warnings.simplefilter('always')
- router.register('mock', MockViewSet, 'mock', base_name='mock')
+ with warnings.catch_warnings(record=True) as w, self.assertRaisesMessage(
+ AssertionError, msg
+ ):
+ warnings.simplefilter("always")
+ router.register("mock", MockViewSet, "mock", base_name="mock")
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1
@@ -511,50 +546,44 @@ class TestBaseNameRename(TestCase):
router = SimpleRouter()
with pytest.warns(RemovedInDRF311Warning) as w:
- warnings.simplefilter('always')
- router.register('mock', MockViewSet, base_name='mock')
+ warnings.simplefilter("always")
+ router.register("mock", MockViewSet, base_name="mock")
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
assert len(w) == 1
assert str(w[0].message) == msg
- assert router.registry == [
- ('mock', MockViewSet, 'mock'),
- ]
+ assert router.registry == [("mock", MockViewSet, "mock")]
def test_basename_argument_no_warnings(self):
router = SimpleRouter()
with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
- router.register('mock', MockViewSet, basename='mock')
+ warnings.simplefilter("always")
+ router.register("mock", MockViewSet, basename="mock")
assert len(w) == 0
- assert router.registry == [
- ('mock', MockViewSet, 'mock'),
- ]
+ assert router.registry == [("mock", MockViewSet, "mock")]
def test_get_default_base_name_deprecation(self):
msg = "`CustomRouter.get_default_base_name` method should be renamed `get_default_basename`."
# Class definition should raise a warning
with pytest.warns(RemovedInDRF311Warning) as w:
- warnings.simplefilter('always')
+ warnings.simplefilter("always")
class CustomRouter(SimpleRouter):
def get_default_base_name(self, viewset):
- return 'foo'
+ return "foo"
assert len(w) == 1
assert str(w[0].message) == msg
# Deprecated method implementation should still be called
with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter('always')
+ warnings.simplefilter("always")
router = CustomRouter()
- router.register('mock', MockViewSet)
+ router.register("mock", MockViewSet)
assert len(w) == 0
- assert router.registry == [
- ('mock', MockViewSet, 'foo'),
- ]
+ assert router.registry == [("mock", MockViewSet, "foo")]
diff --git a/tests/test_schemas.py b/tests/test_schemas.py
index 3cb9e0cda..e1bfd347a 100644
--- a/tests/test_schemas.py
+++ b/tests/test_schemas.py
@@ -6,15 +6,16 @@ from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.test import TestCase, override_settings
-from rest_framework import (
- filters, generics, pagination, permissions, serializers
-)
+from rest_framework import filters, generics, pagination, permissions, serializers
from rest_framework.compat import coreapi, coreschema, get_regex_pattern, path
from rest_framework.decorators import action, api_view, schema
from rest_framework.request import Request
from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.schemas import (
- AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
+ AutoSchema,
+ ManualSchema,
+ SchemaGenerator,
+ get_schema_view,
)
from rest_framework.schemas.generators import EndpointEnumerator
from rest_framework.schemas.inspectors import field_to_schema
@@ -26,6 +27,7 @@ from rest_framework.viewsets import GenericViewSet, ModelViewSet
from .models import BasicModel, ForeignKeySource, ManyToManySource
+
factory = APIRequestFactory()
@@ -36,7 +38,7 @@ class MockUser(object):
class ExamplePagination(pagination.PageNumberPagination):
page_size = 100
- page_size_query_param = 'page_size'
+ page_size_query_param = "page_size"
class EmptySerializer(serializers.Serializer):
@@ -44,10 +46,10 @@ class EmptySerializer(serializers.Serializer):
class ExampleSerializer(serializers.Serializer):
- a = serializers.CharField(required=True, help_text='A field description')
+ a = serializers.CharField(required=True, help_text="A field description")
b = serializers.CharField(required=False)
read_only = serializers.CharField(read_only=True)
- hidden = serializers.HiddenField(default='hello')
+ hidden = serializers.HiddenField(default="hello")
class AnotherSerializerWithDictField(serializers.Serializer):
@@ -70,21 +72,25 @@ class ExampleViewSet(ModelViewSet):
filter_backends = [filters.OrderingFilter]
serializer_class = ExampleSerializer
- @action(methods=['post'], detail=True, serializer_class=AnotherSerializer)
+ @action(methods=["post"], detail=True, serializer_class=AnotherSerializer)
def custom_action(self, request, pk):
"""
A description of custom action.
"""
raise NotImplementedError
- @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithDictField)
+ @action(
+ methods=["post"], detail=True, serializer_class=AnotherSerializerWithDictField
+ )
def custom_action_with_dict_field(self, request, pk):
"""
A custom action using a dict field in the serializer.
"""
raise NotImplementedError
- @action(methods=['post'], detail=True, serializer_class=AnotherSerializerWithListFields)
+ @action(
+ methods=["post"], detail=True, serializer_class=AnotherSerializerWithListFields
+ )
def custom_action_with_list_fields(self, request, pk):
"""
A custom action using both list field and list serializer in the serializer.
@@ -95,7 +101,7 @@ class ExampleViewSet(ModelViewSet):
def custom_list_action(self, request):
raise NotImplementedError
- @action(methods=['post', 'get'], detail=False, serializer_class=EmptySerializer)
+ @action(methods=["post", "get"], detail=False, serializer_class=EmptySerializer)
def custom_list_action_multiple_methods(self, request):
"""Custom description."""
raise NotImplementedError
@@ -114,7 +120,7 @@ class ExampleViewSet(ModelViewSet):
assert self.action
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
- @action(methods=['get', 'post'], detail=False)
+ @action(methods=["get", "post"], detail=False)
def documented_custom_action(self, request):
"""
get:
@@ -134,225 +140,438 @@ class ExampleViewSet(ModelViewSet):
if coreapi:
- schema_view = get_schema_view(title='Example API')
+ schema_view = get_schema_view(title="Example API")
else:
+
def schema_view(request):
pass
+
router = DefaultRouter()
-router.register('example', ExampleViewSet, basename='example')
-urlpatterns = [
- url(r'^$', schema_view),
- url(r'^', include(router.urls))
-]
+router.register("example", ExampleViewSet, basename="example")
+urlpatterns = [url(r"^$", schema_view), url(r"^", include(router.urls))]
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
-@override_settings(ROOT_URLCONF='tests.test_schemas')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
+@override_settings(ROOT_URLCONF="tests.test_schemas")
class TestRouterGeneratedSchema(TestCase):
def test_anonymous_request(self):
client = APIClient()
- response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
+ response = client.get("/", HTTP_ACCEPT="application/coreapi+json")
assert response.status_code == 200
expected = coreapi.Document(
- url='http://testserver/',
- title='Example API',
+ url="http://testserver/",
+ title="Example API",
content={
- 'example': {
- 'list': coreapi.Link(
- url='/example/',
- action='get',
+ "example": {
+ "list": coreapi.Link(
+ url="/example/",
+ action="get",
fields=[
- coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')),
- coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
+ coreapi.Field(
+ "page",
+ required=False,
+ location="query",
+ schema=coreschema.Integer(
+ title="Page",
+ description="A page number within the paginated result set.",
+ ),
+ ),
+ coreapi.Field(
+ "page_size",
+ required=False,
+ location="query",
+ schema=coreschema.Integer(
+ title="Page size",
+ description="Number of results to return per page.",
+ ),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
),
- 'custom_list_action': coreapi.Link(
- url='/example/custom_list_action/',
- action='get'
+ "custom_list_action": coreapi.Link(
+ url="/example/custom_list_action/", action="get"
),
- 'custom_list_action_multiple_methods': {
- 'read': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='get',
- description='Custom description.',
+ "custom_list_action_multiple_methods": {
+ "read": coreapi.Link(
+ url="/example/custom_list_action_multiple_methods/",
+ action="get",
+ description="Custom description.",
)
},
- 'documented_custom_action': {
- 'read': coreapi.Link(
- url='/example/documented_custom_action/',
- action='get',
- description='A description of the get method on the custom action.',
+ "documented_custom_action": {
+ "read": coreapi.Link(
+ url="/example/documented_custom_action/",
+ action="get",
+ description="A description of the get method on the custom action.",
)
},
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
+ "read": coreapi.Link(
+ url="/example/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- )
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
+ ),
}
- }
+ },
)
assert response.data == expected
def test_authenticated_request(self):
client = APIClient()
client.force_authenticate(MockUser())
- response = client.get('/', HTTP_ACCEPT='application/coreapi+json')
+ response = client.get("/", HTTP_ACCEPT="application/coreapi+json")
assert response.status_code == 200
expected = coreapi.Document(
- url='http://testserver/',
- title='Example API',
+ url="http://testserver/",
+ title="Example API",
content={
- 'example': {
- 'list': coreapi.Link(
- url='/example/',
- action='get',
+ "example": {
+ "list": coreapi.Link(
+ url="/example/",
+ action="get",
fields=[
- coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')),
- coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
+ coreapi.Field(
+ "page",
+ required=False,
+ location="query",
+ schema=coreschema.Integer(
+ title="Page",
+ description="A page number within the paginated result set.",
+ ),
+ ),
+ coreapi.Field(
+ "page_size",
+ required=False,
+ location="query",
+ schema=coreschema.Integer(
+ title="Page size",
+ description="Number of results to return per page.",
+ ),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
),
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- encoding='application/json',
+ "create": coreapi.Link(
+ url="/example/",
+ action="post",
+ encoding="application/json",
fields=[
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B'))
- ]
+ coreapi.Field(
+ "a",
+ required=True,
+ location="form",
+ schema=coreschema.String(
+ title="A", description="A field description"
+ ),
+ ),
+ coreapi.Field(
+ "b",
+ required=False,
+ location="form",
+ schema=coreschema.String(title="B"),
+ ),
+ ],
),
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
+ "read": coreapi.Link(
+ url="/example/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
),
- 'custom_action': coreapi.Link(
- url='/example/{id}/custom_action/',
- action='post',
- encoding='application/json',
- description='A description of custom action.',
+ "custom_action": coreapi.Link(
+ url="/example/{id}/custom_action/",
+ action="post",
+ encoding="application/json",
+ description="A description of custom action.",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('c', required=True, location='form', schema=coreschema.String(title='C')),
- coreapi.Field('d', required=False, location='form', schema=coreschema.String(title='D')),
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "c",
+ required=True,
+ location="form",
+ schema=coreschema.String(title="C"),
+ ),
+ coreapi.Field(
+ "d",
+ required=False,
+ location="form",
+ schema=coreschema.String(title="D"),
+ ),
+ ],
),
- 'custom_action_with_dict_field': coreapi.Link(
- url='/example/{id}/custom_action_with_dict_field/',
- action='post',
- encoding='application/json',
- description='A custom action using a dict field in the serializer.',
+ "custom_action_with_dict_field": coreapi.Link(
+ url="/example/{id}/custom_action_with_dict_field/",
+ action="post",
+ encoding="application/json",
+ description="A custom action using a dict field in the serializer.",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=True, location='form', schema=coreschema.Object(title='A')),
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "a",
+ required=True,
+ location="form",
+ schema=coreschema.Object(title="A"),
+ ),
+ ],
),
- 'custom_action_with_list_fields': coreapi.Link(
- url='/example/{id}/custom_action_with_list_fields/',
- action='post',
- encoding='application/json',
- description='A custom action using both list field and list serializer in the serializer.',
+ "custom_action_with_list_fields": coreapi.Link(
+ url="/example/{id}/custom_action_with_list_fields/",
+ action="post",
+ encoding="application/json",
+ description="A custom action using both list field and list serializer in the serializer.",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=True, location='form', schema=coreschema.Array(title='A', items=coreschema.Integer())),
- coreapi.Field('b', required=True, location='form', schema=coreschema.Array(title='B', items=coreschema.String())),
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "a",
+ required=True,
+ location="form",
+ schema=coreschema.Array(
+ title="A", items=coreschema.Integer()
+ ),
+ ),
+ coreapi.Field(
+ "b",
+ required=True,
+ location="form",
+ schema=coreschema.Array(
+ title="B", items=coreschema.String()
+ ),
+ ),
+ ],
),
- 'custom_list_action': coreapi.Link(
- url='/example/custom_list_action/',
- action='get'
+ "custom_list_action": coreapi.Link(
+ url="/example/custom_list_action/", action="get"
),
- 'custom_list_action_multiple_methods': {
- 'read': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='get',
- description='Custom description.',
+ "custom_list_action_multiple_methods": {
+ "read": coreapi.Link(
+ url="/example/custom_list_action_multiple_methods/",
+ action="get",
+ description="Custom description.",
),
- 'create': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='post',
- description='Custom description.',
+ "create": coreapi.Link(
+ url="/example/custom_list_action_multiple_methods/",
+ action="post",
+ description="Custom description.",
),
- 'delete': coreapi.Link(
- url='/example/custom_list_action_multiple_methods/',
- action='delete',
- description='Deletion description.',
+ "delete": coreapi.Link(
+ url="/example/custom_list_action_multiple_methods/",
+ action="delete",
+ description="Deletion description.",
),
},
- 'documented_custom_action': {
- 'read': coreapi.Link(
- url='/example/documented_custom_action/',
- action='get',
- description='A description of the get method on the custom action.',
+ "documented_custom_action": {
+ "read": coreapi.Link(
+ url="/example/documented_custom_action/",
+ action="get",
+ description="A description of the get method on the custom action.",
),
- 'create': coreapi.Link(
- url='/example/documented_custom_action/',
- action='post',
- description='A description of the post method on the custom action.',
- encoding='application/json',
+ "create": coreapi.Link(
+ url="/example/documented_custom_action/",
+ action="post",
+ description="A description of the post method on the custom action.",
+ encoding="application/json",
fields=[
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B'))
- ]
+ coreapi.Field(
+ "a",
+ required=True,
+ location="form",
+ schema=coreschema.String(
+ title="A", description="A field description"
+ ),
+ ),
+ coreapi.Field(
+ "b",
+ required=False,
+ location="form",
+ schema=coreschema.String(title="B"),
+ ),
+ ],
),
- 'update': coreapi.Link(
- url='/example/documented_custom_action/',
- action='put',
- description='A description of the put method on the custom action from mapping.',
- encoding='application/json',
+ "update": coreapi.Link(
+ url="/example/documented_custom_action/",
+ action="put",
+ description="A description of the put method on the custom action from mapping.",
+ encoding="application/json",
fields=[
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B'))
- ]
+ coreapi.Field(
+ "a",
+ required=True,
+ location="form",
+ schema=coreschema.String(
+ title="A", description="A field description"
+ ),
+ ),
+ coreapi.Field(
+ "b",
+ required=False,
+ location="form",
+ schema=coreschema.String(title="B"),
+ ),
+ ],
),
},
- 'update': coreapi.Link(
- url='/example/{id}/',
- action='put',
- encoding='application/json',
+ "update": coreapi.Link(
+ url="/example/{id}/",
+ action="put",
+ encoding="application/json",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=True, location='form', schema=coreschema.String(title='A', description=('A field description'))),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "a",
+ required=True,
+ location="form",
+ schema=coreschema.String(
+ title="A", description=("A field description")
+ ),
+ ),
+ coreapi.Field(
+ "b",
+ required=False,
+ location="form",
+ schema=coreschema.String(title="B"),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
),
- 'partial_update': coreapi.Link(
- url='/example/{id}/',
- action='patch',
- encoding='application/json',
+ "partial_update": coreapi.Link(
+ url="/example/{id}/",
+ action="patch",
+ encoding="application/json",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('a', required=False, location='form', schema=coreschema.String(title='A', description='A field description')),
- coreapi.Field('b', required=False, location='form', schema=coreschema.String(title='B')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "a",
+ required=False,
+ location="form",
+ schema=coreschema.String(
+ title="A", description="A field description"
+ ),
+ ),
+ coreapi.Field(
+ "b",
+ required=False,
+ location="form",
+ schema=coreschema.String(title="B"),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
),
- 'delete': coreapi.Link(
- url='/example/{id}/',
- action='delete',
+ "delete": coreapi.Link(
+ url="/example/{id}/",
+ action="delete",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- )
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
+ ),
}
- }
+ },
)
assert response.data == expected
class DenyAllUsingHttp404(permissions.BasePermission):
-
def has_permission(self, request, view):
raise Http404()
@@ -361,7 +580,6 @@ class DenyAllUsingHttp404(permissions.BasePermission):
class DenyAllUsingPermissionDenied(permissions.BasePermission):
-
def has_permission(self, request, view):
raise PermissionDenied()
@@ -379,7 +597,7 @@ class PermissionDeniedExampleViewSet(ExampleViewSet):
class MethodLimitedViewSet(ExampleViewSet):
permission_classes = []
- http_method_names = ['get', 'head', 'options']
+ http_method_names = ["get", "head", "options"]
class ExampleListView(APIView):
@@ -399,118 +617,122 @@ class ExampleDetailView(APIView):
pass
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class TestSchemaGenerator(TestCase):
def setUp(self):
self.patterns = [
- url(r'^example/?$', ExampleListView.as_view()),
- url(r'^example/(?P\d+)/?$', ExampleDetailView.as_view()),
- url(r'^example/(?P\d+)/sub/?$', ExampleDetailView.as_view()),
+ url(r"^example/?$", ExampleListView.as_view()),
+ url(r"^example/(?P\d+)/?$", ExampleDetailView.as_view()),
+ url(r"^example/(?P\d+)/sub/?$", ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
"""
Ensure that schema generation works for APIView classes.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ generator = SchemaGenerator(title="Example API", patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
+ "example": {
+ "create": coreapi.Link(url="/example/", action="post", fields=[]),
+ "list": coreapi.Link(url="/example/", action="get", fields=[]),
+ "read": coreapi.Link(
+ url="/example/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'sub': {
- 'list': coreapi.Link(
- url='/example/{id}/sub/',
- action='get',
+ "sub": {
+ "list": coreapi.Link(
+ url="/example/{id}/sub/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
)
- }
+ },
}
- }
+ },
)
assert schema == expected
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
-@unittest.skipUnless(path, 'needs Django 2')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
+@unittest.skipUnless(path, "needs Django 2")
class TestSchemaGeneratorDjango2(TestCase):
def setUp(self):
self.patterns = [
- path('example/', ExampleListView.as_view()),
- path('example//', ExampleDetailView.as_view()),
- path('example//sub/', ExampleDetailView.as_view()),
+ path("example/", ExampleListView.as_view()),
+ path("example//", ExampleDetailView.as_view()),
+ path("example//sub/", ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
"""
Ensure that schema generation works for APIView classes.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ generator = SchemaGenerator(title="Example API", patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/example/{id}/',
- action='get',
+ "example": {
+ "create": coreapi.Link(url="/example/", action="post", fields=[]),
+ "list": coreapi.Link(url="/example/", action="get", fields=[]),
+ "read": coreapi.Link(
+ url="/example/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'sub': {
- 'list': coreapi.Link(
- url='/example/{id}/sub/',
- action='get',
+ "sub": {
+ "list": coreapi.Link(
+ url="/example/{id}/sub/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
)
- }
+ },
}
- }
+ },
)
assert schema == expected
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class TestSchemaGeneratorNotAtRoot(TestCase):
def setUp(self):
self.patterns = [
- url(r'^api/v1/example/?$', ExampleListView.as_view()),
- url(r'^api/v1/example/(?P\d+)/?$', ExampleDetailView.as_view()),
- url(r'^api/v1/example/(?P\d+)/sub/?$', ExampleDetailView.as_view()),
+ url(r"^api/v1/example/?$", ExampleListView.as_view()),
+ url(r"^api/v1/example/(?P\d+)/?$", ExampleDetailView.as_view()),
+ url(r"^api/v1/example/(?P\d+)/sub/?$", ExampleDetailView.as_view()),
]
def test_schema_for_regular_views(self):
@@ -518,118 +740,158 @@ class TestSchemaGeneratorNotAtRoot(TestCase):
Ensure that schema generation with an API that is not at the URL
root continues to use correct structure for link keys.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ generator = SchemaGenerator(title="Example API", patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'example': {
- 'create': coreapi.Link(
- url='/api/v1/example/',
- action='post',
- fields=[]
+ "example": {
+ "create": coreapi.Link(
+ url="/api/v1/example/", action="post", fields=[]
),
- 'list': coreapi.Link(
- url='/api/v1/example/',
- action='get',
- fields=[]
+ "list": coreapi.Link(
+ url="/api/v1/example/", action="get", fields=[]
),
- 'read': coreapi.Link(
- url='/api/v1/example/{id}/',
- action='get',
+ "read": coreapi.Link(
+ url="/api/v1/example/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'sub': {
- 'list': coreapi.Link(
- url='/api/v1/example/{id}/sub/',
- action='get',
+ "sub": {
+ "list": coreapi.Link(
+ url="/api/v1/example/{id}/sub/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
)
- }
+ },
}
- }
+ },
)
assert schema == expected
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class TestSchemaGeneratorWithMethodLimitedViewSets(TestCase):
def setUp(self):
router = DefaultRouter()
- router.register('example1', MethodLimitedViewSet, basename='example1')
- self.patterns = [
- url(r'^', include(router.urls))
- ]
+ router.register("example1", MethodLimitedViewSet, basename="example1")
+ self.patterns = [url(r"^", include(router.urls))]
def test_schema_for_regular_views(self):
"""
Ensure that schema generation works for ViewSet classes
with method limitation by Django CBV's http_method_names attribute
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- request = factory.get('/example1/')
+ generator = SchemaGenerator(title="Example API", patterns=self.patterns)
+ request = factory.get("/example1/")
schema = generator.get_schema(Request(request))
expected = coreapi.Document(
- url='http://testserver/example1/',
- title='Example API',
+ url="http://testserver/example1/",
+ title="Example API",
content={
- 'example1': {
- 'list': coreapi.Link(
- url='/example1/',
- action='get',
+ "example1": {
+ "list": coreapi.Link(
+ url="/example1/",
+ action="get",
fields=[
- coreapi.Field('page', required=False, location='query', schema=coreschema.Integer(title='Page', description='A page number within the paginated result set.')),
- coreapi.Field('page_size', required=False, location='query', schema=coreschema.Integer(title='Page size', description='Number of results to return per page.')),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
+ coreapi.Field(
+ "page",
+ required=False,
+ location="query",
+ schema=coreschema.Integer(
+ title="Page",
+ description="A page number within the paginated result set.",
+ ),
+ ),
+ coreapi.Field(
+ "page_size",
+ required=False,
+ location="query",
+ schema=coreschema.Integer(
+ title="Page size",
+ description="Number of results to return per page.",
+ ),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
),
- 'custom_list_action': coreapi.Link(
- url='/example1/custom_list_action/',
- action='get'
+ "custom_list_action": coreapi.Link(
+ url="/example1/custom_list_action/", action="get"
),
- 'custom_list_action_multiple_methods': {
- 'read': coreapi.Link(
- url='/example1/custom_list_action_multiple_methods/',
- action='get',
- description='Custom description.',
+ "custom_list_action_multiple_methods": {
+ "read": coreapi.Link(
+ url="/example1/custom_list_action_multiple_methods/",
+ action="get",
+ description="Custom description.",
)
},
- 'documented_custom_action': {
- 'read': coreapi.Link(
- url='/example1/documented_custom_action/',
- action='get',
- description='A description of the get method on the custom action.',
- ),
+ "documented_custom_action": {
+ "read": coreapi.Link(
+ url="/example1/documented_custom_action/",
+ action="get",
+ description="A description of the get method on the custom action.",
+ )
},
- 'read': coreapi.Link(
- url='/example1/{id}/',
- action='get',
+ "read": coreapi.Link(
+ url="/example1/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
- coreapi.Field('ordering', required=False, location='query', schema=coreschema.String(title='Ordering', description='Which field to use when ordering the results.'))
- ]
- )
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ ),
+ coreapi.Field(
+ "ordering",
+ required=False,
+ location="query",
+ schema=coreschema.String(
+ title="Ordering",
+ description="Which field to use when ordering the results.",
+ ),
+ ),
+ ],
+ ),
}
- }
+ },
)
assert schema == expected
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
def setUp(self):
router = DefaultRouter()
- router.register('example1', Http404ExampleViewSet, basename='example1')
- router.register('example2', PermissionDeniedExampleViewSet, basename='example2')
+ router.register("example1", Http404ExampleViewSet, basename="example1")
+ router.register("example2", PermissionDeniedExampleViewSet, basename="example2")
self.patterns = [
- url('^example/?$', ExampleListView.as_view()),
- url(r'^', include(router.urls))
+ url("^example/?$", ExampleListView.as_view()),
+ url(r"^", include(router.urls)),
]
def test_schema_for_regular_views(self):
@@ -637,21 +899,17 @@ class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
Ensure that schema generation works for ViewSet classes
with permission classes raising exceptions.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
- request = factory.get('/')
+ generator = SchemaGenerator(title="Example API", patterns=self.patterns)
+ request = factory.get("/")
schema = generator.get_schema(Request(request))
expected = coreapi.Document(
- url='http://testserver/',
- title='Example API',
+ url="http://testserver/",
+ title="Example API",
content={
- 'example': {
- 'list': coreapi.Link(
- url='/example/',
- action='get',
- fields=[]
- ),
- },
- }
+ "example": {
+ "list": coreapi.Link(url="/example/", action="get", fields=[])
+ }
+ },
)
assert schema == expected
@@ -659,7 +917,7 @@ class TestSchemaGeneratorWithRestrictedViewSets(TestCase):
class ForeignKeySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
- fields = ('id', 'name', 'target')
+ fields = ("id", "name", "target")
class ForeignKeySourceView(generics.CreateAPIView):
@@ -667,36 +925,46 @@ class ForeignKeySourceView(generics.CreateAPIView):
serializer_class = ForeignKeySourceSerializer
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class TestSchemaGeneratorWithForeignKey(TestCase):
def setUp(self):
- self.patterns = [
- url(r'^example/?$', ForeignKeySourceView.as_view()),
- ]
+ self.patterns = [url(r"^example/?$", ForeignKeySourceView.as_view())]
def test_schema_for_regular_views(self):
"""
Ensure that AutoField foreign keys are output as Integer.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ generator = SchemaGenerator(title="Example API", patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- encoding='application/json',
+ "example": {
+ "create": coreapi.Link(
+ url="/example/",
+ action="post",
+ encoding="application/json",
fields=[
- coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')),
- coreapi.Field('target', required=True, location='form', schema=coreschema.Integer(description='Target', title='Target')),
- ]
+ coreapi.Field(
+ "name",
+ required=True,
+ location="form",
+ schema=coreschema.String(title="Name"),
+ ),
+ coreapi.Field(
+ "target",
+ required=True,
+ location="form",
+ schema=coreschema.Integer(
+ description="Target", title="Target"
+ ),
+ ),
+ ],
)
}
- }
+ },
)
assert schema == expected
@@ -704,7 +972,7 @@ class TestSchemaGeneratorWithForeignKey(TestCase):
class ManyToManySourceSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManySource
- fields = ('id', 'name', 'targets')
+ fields = ("id", "name", "targets")
class ManyToManySourceView(generics.CreateAPIView):
@@ -712,61 +980,70 @@ class ManyToManySourceView(generics.CreateAPIView):
serializer_class = ManyToManySourceSerializer
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class TestSchemaGeneratorWithManyToMany(TestCase):
def setUp(self):
- self.patterns = [
- url(r'^example/?$', ManyToManySourceView.as_view()),
- ]
+ self.patterns = [url(r"^example/?$", ManyToManySourceView.as_view())]
def test_schema_for_regular_views(self):
"""
Ensure that AutoField many to many fields are output as Integer.
"""
- generator = SchemaGenerator(title='Example API', patterns=self.patterns)
+ generator = SchemaGenerator(title="Example API", patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'example': {
- 'create': coreapi.Link(
- url='/example/',
- action='post',
- encoding='application/json',
+ "example": {
+ "create": coreapi.Link(
+ url="/example/",
+ action="post",
+ encoding="application/json",
fields=[
- coreapi.Field('name', required=True, location='form', schema=coreschema.String(title='Name')),
- coreapi.Field('targets', required=True, location='form', schema=coreschema.Array(title='Targets', items=coreschema.Integer())),
- ]
+ coreapi.Field(
+ "name",
+ required=True,
+ location="form",
+ schema=coreschema.String(title="Name"),
+ ),
+ coreapi.Field(
+ "targets",
+ required=True,
+ location="form",
+ schema=coreschema.Array(
+ title="Targets", items=coreschema.Integer()
+ ),
+ ),
+ ],
)
}
- }
+ },
)
assert schema == expected
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class Test4605Regression(TestCase):
def test_4605_regression(self):
generator = SchemaGenerator()
- prefix = generator.determine_path_prefix([
- '/api/v1/items/',
- '/auth/convert-token/'
- ])
- assert prefix == '/'
+ prefix = generator.determine_path_prefix(
+ ["/api/v1/items/", "/auth/convert-token/"]
+ )
+ assert prefix == "/"
class CustomViewInspector(AutoSchema):
"""A dummy AutoSchema subclass"""
+
pass
class TestAutoSchema(TestCase):
-
def test_apiview_schema_descriptor(self):
view = APIView()
- assert hasattr(view, 'schema')
+ assert hasattr(view, "schema")
assert isinstance(view.schema, AutoSchema)
def test_set_custom_inspector_class_on_view(self):
@@ -777,16 +1054,22 @@ class TestAutoSchema(TestCase):
assert isinstance(view.schema, CustomViewInspector)
def test_set_custom_inspector_class_via_settings(self):
- with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'tests.test_schemas.CustomViewInspector'}):
+ with override_settings(
+ REST_FRAMEWORK={
+ "DEFAULT_SCHEMA_CLASS": "tests.test_schemas.CustomViewInspector"
+ }
+ ):
view = APIView()
assert isinstance(view.schema, CustomViewInspector)
def test_get_link_requires_instance(self):
descriptor = APIView.schema # Accessed from class
with pytest.raises(AssertionError):
- descriptor.get_link(None, None, None) # ???: Do the dummy arguments require a tighter assert?
+ descriptor.get_link(
+ None, None, None
+ ) # ???: Do the dummy arguments require a tighter assert?
- @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+ @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_update_fields(self):
"""
That updating fields by-name helper is correct
@@ -797,78 +1080,91 @@ class TestAutoSchema(TestCase):
fields = []
# Adds a field...
- fields = schema.update_fields(fields, [
- coreapi.Field(
- "my_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- ])
+ fields = schema.update_fields(
+ fields,
+ [
+ coreapi.Field(
+ "my_field",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ )
assert len(fields) == 1
assert fields[0].name == "my_field"
# Replaces a field...
- fields = schema.update_fields(fields, [
- coreapi.Field(
- "my_field",
- required=False,
- location="path",
- schema=coreschema.String()
- ),
- ])
+ fields = schema.update_fields(
+ fields,
+ [
+ coreapi.Field(
+ "my_field",
+ required=False,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ )
assert len(fields) == 1
assert fields[0].required is False
- @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+ @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_get_manual_fields(self):
"""That get_manual_fields is applied during get_link"""
class CustomView(APIView):
- schema = AutoSchema(manual_fields=[
- coreapi.Field(
- "my_extra_field",
- required=True,
- location="path",
- schema=coreschema.String()
- ),
- ])
+ schema = AutoSchema(
+ manual_fields=[
+ coreapi.Field(
+ "my_extra_field",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ]
+ )
view = CustomView()
- link = view.schema.get_link('/a/url/{id}/', 'GET', '')
+ link = view.schema.get_link("/a/url/{id}/", "GET", "")
fields = link.fields
assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]
- @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+ @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_viewset_action_with_schema(self):
class CustomViewSet(GenericViewSet):
- @action(detail=True, schema=AutoSchema(manual_fields=[
- coreapi.Field(
- "my_extra_field",
- required=True,
- location="path",
- schema=coreschema.String()
+ @action(
+ detail=True,
+ schema=AutoSchema(
+ manual_fields=[
+ coreapi.Field(
+ "my_extra_field",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ]
),
- ]))
+ )
def extra_action(self, pk, **kwargs):
pass
router = SimpleRouter()
- router.register(r'detail', CustomViewSet, basename='detail')
+ router.register(r"detail", CustomViewSet, basename="detail")
generator = SchemaGenerator()
- view = generator.create_view(router.urls[0].callback, 'GET')
- link = view.schema.get_link('/a/url/{id}/', 'GET', '')
+ view = generator.create_view(router.urls[0].callback, "GET")
+ link = view.schema.get_link("/a/url/{id}/", "GET", "")
fields = link.fields
assert len(fields) == 2
assert "my_extra_field" in [f.name for f in fields]
- @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+ @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_viewset_action_with_null_schema(self):
class CustomViewSet(GenericViewSet):
@action(detail=True, schema=None)
@@ -876,17 +1172,17 @@ class TestAutoSchema(TestCase):
pass
router = SimpleRouter()
- router.register(r'detail', CustomViewSet, basename='detail')
+ router.register(r"detail", CustomViewSet, basename="detail")
generator = SchemaGenerator()
- view = generator.create_view(router.urls[0].callback, 'GET')
+ view = generator.create_view(router.urls[0].callback, "GET")
assert view.schema is None
- @pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+ @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_view_with_manual_schema(self):
- path = '/example'
- method = 'get'
+ path = "/example"
+ method = "get"
base_url = None
fields = [
@@ -894,19 +1190,19 @@ class TestAutoSchema(TestCase):
"first_field",
required=True,
location="path",
- schema=coreschema.String()
+ schema=coreschema.String(),
),
coreapi.Field(
"second_field",
required=True,
location="path",
- schema=coreschema.String()
+ schema=coreschema.String(),
),
coreapi.Field(
"third_field",
required=True,
location="path",
- schema=coreschema.String()
+ schema=coreschema.String(),
),
]
description = "A test endpoint"
@@ -916,54 +1212,54 @@ class TestAutoSchema(TestCase):
ManualSchema takes list of fields for endpoint.
- Provides url and action, which are always dynamic
"""
+
schema = ManualSchema(fields, description)
expected = coreapi.Link(
- url=path,
- action=method,
- fields=fields,
- description=description
+ url=path, action=method, fields=fields, description=description
)
view = CustomView()
link = view.schema.get_link(path, method, base_url)
assert link == expected
- @unittest.skipUnless(coreschema, 'coreschema is not installed')
+ @unittest.skipUnless(coreschema, "coreschema is not installed")
def test_field_to_schema(self):
- label = 'Test label'
- help_text = 'This is a helpful test text'
+ label = "Test label"
+ help_text = "This is a helpful test text"
cases = [
# tuples are ([field], [expected schema])
# TODO: Add remaining cases
(
serializers.BooleanField(label=label, help_text=help_text),
- coreschema.Boolean(title=label, description=help_text)
+ coreschema.Boolean(title=label, description=help_text),
),
(
serializers.DecimalField(1000, 1000, label=label, help_text=help_text),
- coreschema.Number(title=label, description=help_text)
+ coreschema.Number(title=label, description=help_text),
),
(
serializers.FloatField(label=label, help_text=help_text),
- coreschema.Number(title=label, description=help_text)
+ coreschema.Number(title=label, description=help_text),
),
(
serializers.IntegerField(label=label, help_text=help_text),
- coreschema.Integer(title=label, description=help_text)
+ coreschema.Integer(title=label, description=help_text),
),
(
serializers.DateField(label=label, help_text=help_text),
- coreschema.String(title=label, description=help_text, format='date')
+ coreschema.String(title=label, description=help_text, format="date"),
),
(
serializers.DateTimeField(label=label, help_text=help_text),
- coreschema.String(title=label, description=help_text, format='date-time')
+ coreschema.String(
+ title=label, description=help_text, format="date-time"
+ ),
),
(
serializers.JSONField(label=label, help_text=help_text),
- coreschema.Object(title=label, description=help_text)
+ coreschema.Object(title=label, description=help_text),
),
]
@@ -1001,7 +1297,7 @@ def test_docstring_is_not_stripped_by_get_description():
view = ExampleDocstringAPIView()
schema = view.schema
- descr = schema.get_description('example', 'get')
+ descr = schema.get_description("example", "get")
# the first and last character are '\n' correctly removed by get_description
assert descr == formatting.dedent(ExampleDocstringAPIView.__doc__[1:][:-1])
@@ -1014,42 +1310,42 @@ class ExcludedAPIView(APIView):
pass
-@api_view(['GET'])
+@api_view(["GET"])
@schema(None)
def excluded_fbv(request):
pass
-@api_view(['GET'])
+@api_view(["GET"])
def included_fbv(request):
pass
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class SchemaGenerationExclusionTests(TestCase):
def setUp(self):
self.patterns = [
- url('^excluded-cbv/$', ExcludedAPIView.as_view()),
- url('^excluded-fbv/$', excluded_fbv),
- url('^included-fbv/$', included_fbv),
+ url("^excluded-cbv/$", ExcludedAPIView.as_view()),
+ url("^excluded-fbv/$", excluded_fbv),
+ url("^included-fbv/$", included_fbv),
]
def test_schema_generator_excludes_correctly(self):
"""Schema should not include excluded views"""
- generator = SchemaGenerator(title='Exclusions', patterns=self.patterns)
+ generator = SchemaGenerator(title="Exclusions", patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
- url='',
- title='Exclusions',
+ url="",
+ title="Exclusions",
content={
- 'included-fbv': {
- 'list': coreapi.Link(url='/included-fbv/', action='get')
+ "included-fbv": {
+ "list": coreapi.Link(url="/included-fbv/", action="get")
}
- }
+ },
)
assert len(schema.data) == 1
- assert 'included-fbv' in schema.data
+ assert "included-fbv" in schema.data
assert schema == expected
def test_endpoint_enumerator_excludes_correctly(self):
@@ -1059,20 +1355,23 @@ class SchemaGenerationExclusionTests(TestCase):
assert len(endpoints) == 1
path, method, callback = endpoints[0]
- assert path == '/included-fbv/'
+ assert path == "/included-fbv/"
def test_should_include_endpoint_excludes_correctly(self):
"""This is the specific method that should handle the exclusion"""
inspector = EndpointEnumerator(self.patterns)
# Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
- pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback)
- for pattern in self.patterns]
-
- should_include = [
- inspector.should_include_endpoint(*pair) for pair in pairs
+ pairs = [
+ (
+ inspector.get_path_from_regex(get_regex_pattern(pattern)),
+ pattern.callback,
+ )
+ for pattern in self.patterns
]
+ should_include = [inspector.should_include_endpoint(*pair) for pair in pairs]
+
expected = [False, False, True]
assert should_include == expected
@@ -1102,154 +1401,153 @@ class NamingCollisionViewSet(GenericViewSet):
"""
Example via: https://stackoverflow.com/questions/43778668/django-rest-framwork-occured-typeerror-link-object-does-not-support-item-ass/
"""
+
permision_class = ()
@action(detail=False)
def detail(self, request):
return {}
- @action(detail=False, url_path='detail/export')
+ @action(detail=False, url_path="detail/export")
def detail_export(self, request):
return {}
naming_collisions_router = SimpleRouter()
-naming_collisions_router.register(r'collision', NamingCollisionViewSet, basename="collision")
+naming_collisions_router.register(
+ r"collision", NamingCollisionViewSet, basename="collision"
+)
-@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
class TestURLNamingCollisions(TestCase):
"""
Ref: https://github.com/encode/django-rest-framework/issues/4704
"""
- def test_manually_routing_nested_routes(self):
- patterns = [
- url(r'^test', simple_fbv),
- url(r'^test/list/', simple_fbv),
- ]
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ def test_manually_routing_nested_routes(self):
+ patterns = [url(r"^test", simple_fbv), url(r"^test/list/", simple_fbv)]
+
+ generator = SchemaGenerator(title="Naming Colisions", patterns=patterns)
schema = generator.get_schema()
expected = coreapi.Document(
- url='',
- title='Naming Colisions',
+ url="",
+ title="Naming Colisions",
content={
- 'test': {
- 'list': {
- 'list': coreapi.Link(url='/test/list/', action='get')
- },
- 'list_0': coreapi.Link(url='/test', action='get')
+ "test": {
+ "list": {"list": coreapi.Link(url="/test/list/", action="get")},
+ "list_0": coreapi.Link(url="/test", action="get"),
}
- }
+ },
)
assert expected == schema
def _verify_cbv_links(self, loc, url, methods=None, suffixes=None):
if methods is None:
- methods = ('read', 'update', 'partial_update', 'delete')
+ methods = ("read", "update", "partial_update", "delete")
if suffixes is None:
suffixes = (None for m in methods)
for method, suffix in zip(methods, suffixes):
if suffix is not None:
- key = '{}_{}'.format(method, suffix)
+ key = "{}_{}".format(method, suffix)
else:
key = method
assert loc[key].url == url
def test_manually_routing_generic_view(self):
patterns = [
- url(r'^test', NamingCollisionView.as_view()),
- url(r'^test/retrieve/', NamingCollisionView.as_view()),
- url(r'^test/update/', NamingCollisionView.as_view()),
-
+ url(r"^test", NamingCollisionView.as_view()),
+ url(r"^test/retrieve/", NamingCollisionView.as_view()),
+ url(r"^test/update/", NamingCollisionView.as_view()),
# Fails with method names:
- url(r'^test/get/', NamingCollisionView.as_view()),
- url(r'^test/put/', NamingCollisionView.as_view()),
- url(r'^test/delete/', NamingCollisionView.as_view()),
+ url(r"^test/get/", NamingCollisionView.as_view()),
+ url(r"^test/put/", NamingCollisionView.as_view()),
+ url(r"^test/delete/", NamingCollisionView.as_view()),
]
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ generator = SchemaGenerator(title="Naming Colisions", patterns=patterns)
schema = generator.get_schema()
- self._verify_cbv_links(schema['test']['delete'], '/test/delete/')
- self._verify_cbv_links(schema['test']['put'], '/test/put/')
- self._verify_cbv_links(schema['test']['get'], '/test/get/')
- self._verify_cbv_links(schema['test']['update'], '/test/update/')
- self._verify_cbv_links(schema['test']['retrieve'], '/test/retrieve/')
- self._verify_cbv_links(schema['test'], '/test', suffixes=(None, '0', None, '0'))
+ self._verify_cbv_links(schema["test"]["delete"], "/test/delete/")
+ self._verify_cbv_links(schema["test"]["put"], "/test/put/")
+ self._verify_cbv_links(schema["test"]["get"], "/test/get/")
+ self._verify_cbv_links(schema["test"]["update"], "/test/update/")
+ self._verify_cbv_links(schema["test"]["retrieve"], "/test/retrieve/")
+ self._verify_cbv_links(schema["test"], "/test", suffixes=(None, "0", None, "0"))
def test_from_router(self):
- patterns = [
- url(r'from-router', include(naming_collisions_router.urls)),
- ]
+ patterns = [url(r"from-router", include(naming_collisions_router.urls))]
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ generator = SchemaGenerator(title="Naming Colisions", patterns=patterns)
schema = generator.get_schema()
# not important here
- desc_0 = schema['detail']['detail_export'].description
- desc_1 = schema['detail_0'].description
+ desc_0 = schema["detail"]["detail_export"].description
+ desc_1 = schema["detail_0"].description
expected = coreapi.Document(
- url='',
- title='Naming Colisions',
+ url="",
+ title="Naming Colisions",
content={
- 'detail': {
- 'detail_export': coreapi.Link(
- url='/from-routercollision/detail/export/',
- action='get',
- description=desc_0)
+ "detail": {
+ "detail_export": coreapi.Link(
+ url="/from-routercollision/detail/export/",
+ action="get",
+ description=desc_0,
+ )
},
- 'detail_0': coreapi.Link(
- url='/from-routercollision/detail/',
- action='get',
- description=desc_1
- )
- }
+ "detail_0": coreapi.Link(
+ url="/from-routercollision/detail/",
+ action="get",
+ description=desc_1,
+ ),
+ },
)
assert schema == expected
def test_url_under_same_key_not_replaced(self):
patterns = [
- url(r'example/(?P\d+)/$', BasicNamingCollisionView.as_view()),
- url(r'example/(?P\w+)/$', BasicNamingCollisionView.as_view()),
+ url(r"example/(?P\d+)/$", BasicNamingCollisionView.as_view()),
+ url(r"example/(?P\w+)/$", BasicNamingCollisionView.as_view()),
]
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ generator = SchemaGenerator(title="Naming Colisions", patterns=patterns)
schema = generator.get_schema()
- assert schema['example']['read'].url == '/example/{id}/'
- assert schema['example']['read_0'].url == '/example/{slug}/'
+ assert schema["example"]["read"].url == "/example/{id}/"
+ assert schema["example"]["read_0"].url == "/example/{slug}/"
def test_url_under_same_key_not_replaced_another(self):
patterns = [
- url(r'^test/list/', simple_fbv),
- url(r'^test/(?P\d+)/list/', simple_fbv),
+ url(r"^test/list/", simple_fbv),
+ url(r"^test/(?P\d+)/list/", simple_fbv),
]
- generator = SchemaGenerator(title='Naming Colisions', patterns=patterns)
+ generator = SchemaGenerator(title="Naming Colisions", patterns=patterns)
schema = generator.get_schema()
- assert schema['test']['list']['list'].url == '/test/list/'
- assert schema['test']['list']['list_0'].url == '/test/{id}/list/'
+ assert schema["test"]["list"]["list"].url == "/test/list/"
+ assert schema["test"]["list"]["list_0"].url == "/test/{id}/list/"
def test_is_list_view_recognises_retrieve_view_subclasses():
class TestView(generics.RetrieveAPIView):
pass
- path = '/looks/like/a/list/view/'
- method = 'get'
+ path = "/looks/like/a/list/view/"
+ method = "get"
view = TestView()
is_list = is_list_view(path, method, view)
- assert not is_list, "RetrieveAPIView subclasses should not be classified as list views."
+ assert (
+ not is_list
+ ), "RetrieveAPIView subclasses should not be classified as list views."
def test_head_and_options_methods_are_excluded():
@@ -1262,20 +1560,19 @@ def test_head_and_options_methods_are_excluded():
Initial cases here shown to be working as expected.
"""
- @api_view(['options', 'get'])
+ @api_view(["options", "get"])
def fbv(request):
pass
inspector = EndpointEnumerator()
- path = '/a/path/'
+ path = "/a/path/"
callback = fbv
assert inspector.should_include_endpoint(path, callback)
assert inspector.get_allowed_methods(callback) == ["GET"]
class AnAPIView(APIView):
-
def get(self, request, *args, **kwargs):
pass
@@ -1288,73 +1585,69 @@ def test_head_and_options_methods_are_excluded():
assert inspector.get_allowed_methods(callback) == ["GET"]
class AViewSet(ModelViewSet):
-
- @action(methods=['options', 'get'], detail=True)
+ @action(methods=["options", "get"], detail=True)
def custom_action(self, request, pk):
pass
- callback = AViewSet.as_view({
- "options": "custom_action",
- "get": "custom_action"
- })
+ callback = AViewSet.as_view({"options": "custom_action", "get": "custom_action"})
assert inspector.should_include_endpoint(path, callback)
assert inspector.get_allowed_methods(callback) == ["GET"]
-@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
class TestAutoSchemaAllowsFilters(object):
class MockAPIView(APIView):
filter_backends = [filters.OrderingFilter]
def _test(self, method):
view = self.MockAPIView()
- fields = view.schema.get_filter_fields('', method)
+ fields = view.schema.get_filter_fields("", method)
field_names = [f.name for f in fields]
- return 'ordering' in field_names
+ return "ordering" in field_names
def test_get(self):
- assert self._test('get')
+ assert self._test("get")
def test_GET(self):
- assert self._test('GET')
+ assert self._test("GET")
def test_put(self):
- assert self._test('put')
+ assert self._test("put")
def test_PUT(self):
- assert self._test('PUT')
+ assert self._test("PUT")
def test_patch(self):
- assert self._test('patch')
+ assert self._test("patch")
def test_PATCH(self):
- assert self._test('PATCH')
+ assert self._test("PATCH")
def test_delete(self):
- assert self._test('delete')
+ assert self._test("delete")
def test_DELETE(self):
- assert self._test('DELETE')
+ assert self._test("DELETE")
def test_post(self):
- assert not self._test('post')
+ assert not self._test("post")
def test_POST(self):
- assert not self._test('POST')
+ assert not self._test("POST")
def test_foo(self):
- assert not self._test('foo')
+ assert not self._test("foo")
def test_FOO(self):
- assert not self._test('FOO')
+ assert not self._test("FOO")
-@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
+@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_schema_handles_exception():
schema_view = get_schema_view(permission_classes=[DenyAllUsingPermissionDenied])
- request = factory.get('/')
+ request = factory.get("/")
response = schema_view(request)
response.render()
assert response.status_code == 403
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
index 0f1e81965..514df5308 100644
--- a/tests/test_serializer.py
+++ b/tests/test_serializer.py
@@ -13,11 +13,10 @@ from rest_framework import exceptions, fields, relations, serializers
from rest_framework.compat import Mapping, unicode_repr
from rest_framework.fields import Field
-from .models import (
- ForeignKeyTarget, NestedForeignKeySource, NullableForeignKeySource
-)
+from .models import ForeignKeyTarget, NestedForeignKeySource, NullableForeignKeySource
from .utils import MockObject
+
try:
from collections import ChainMap
except ImportError:
@@ -27,25 +26,26 @@ except ImportError:
# Test serializer fields imports.
# -------------------------------
+
class TestFieldImports:
def is_field(self, name, value):
return (
- isinstance(value, type) and
- issubclass(value, Field) and
- not name.startswith('_')
+ isinstance(value, type)
+ and issubclass(value, Field)
+ and not name.startswith("_")
)
def test_fields(self):
msg = "Expected `fields.%s` to be imported in `serializers`"
field_classes = [
- key for key, value
- in inspect.getmembers(fields)
+ key
+ for key, value in inspect.getmembers(fields)
if self.is_field(key, value)
]
# sanity check
- assert 'Field' in field_classes
- assert 'BooleanField' in field_classes
+ assert "Field" in field_classes
+ assert "BooleanField" in field_classes
for field in field_classes:
assert hasattr(serializers, field), msg % field
@@ -53,13 +53,13 @@ class TestFieldImports:
def test_relations(self):
msg = "Expected `relations.%s` to be imported in `serializers`"
field_classes = [
- key for key, value
- in inspect.getmembers(relations)
+ key
+ for key, value in inspect.getmembers(relations)
if self.is_field(key, value)
]
# sanity check
- assert 'RelatedField' in field_classes
+ assert "RelatedField" in field_classes
for field in field_classes:
assert hasattr(serializers, field), msg % field
@@ -68,47 +68,52 @@ class TestFieldImports:
# Tests for core functionality.
# -----------------------------
+
class TestSerializer:
def setup(self):
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
+
self.Serializer = ExampleSerializer
def test_valid_serializer(self):
- serializer = self.Serializer(data={'char': 'abc', 'integer': 123})
+ serializer = self.Serializer(data={"char": "abc", "integer": 123})
assert serializer.is_valid()
- assert serializer.validated_data == {'char': 'abc', 'integer': 123}
- assert serializer.data == {'char': 'abc', 'integer': 123}
+ assert serializer.validated_data == {"char": "abc", "integer": 123}
+ assert serializer.data == {"char": "abc", "integer": 123}
assert serializer.errors == {}
def test_invalid_serializer(self):
- serializer = self.Serializer(data={'char': 'abc'})
+ serializer = self.Serializer(data={"char": "abc"})
assert not serializer.is_valid()
assert serializer.validated_data == {}
- assert serializer.data == {'char': 'abc'}
- assert serializer.errors == {'integer': ['This field is required.']}
+ assert serializer.data == {"char": "abc"}
+ assert serializer.errors == {"integer": ["This field is required."]}
def test_invalid_datatype(self):
- serializer = self.Serializer(data=[{'char': 'abc'}])
+ serializer = self.Serializer(data=[{"char": "abc"}])
assert not serializer.is_valid()
assert serializer.validated_data == {}
assert serializer.data == {}
- assert serializer.errors == {'non_field_errors': ['Invalid data. Expected a dictionary, but got list.']}
+ assert serializer.errors == {
+ "non_field_errors": ["Invalid data. Expected a dictionary, but got list."]
+ }
def test_partial_validation(self):
- serializer = self.Serializer(data={'char': 'abc'}, partial=True)
+ serializer = self.Serializer(data={"char": "abc"}, partial=True)
assert serializer.is_valid()
- assert serializer.validated_data == {'char': 'abc'}
+ assert serializer.validated_data == {"char": "abc"}
assert serializer.errors == {}
def test_empty_serializer(self):
serializer = self.Serializer()
- assert serializer.data == {'char': '', 'integer': None}
+ assert serializer.data == {"char": "", "integer": None}
def test_missing_attribute_during_serialization(self):
class MissingAttributes:
pass
+
instance = MissingAttributes()
serializer = self.Serializer(instance)
with pytest.raises(AttributeError):
@@ -117,10 +122,11 @@ class TestSerializer:
def test_data_access_before_save_raises_error(self):
def create(validated_data):
return validated_data
- serializer = self.Serializer(data={'char': 'abc', 'integer': 123})
+
+ serializer = self.Serializer(data={"char": "abc", "integer": 123})
serializer.create = create
assert serializer.is_valid()
- assert serializer.data == {'char': 'abc', 'integer': 123}
+ assert serializer.data == {"char": "abc", "integer": 123}
with pytest.raises(AssertionError):
serializer.save()
@@ -128,31 +134,31 @@ class TestSerializer:
data = None
serializer = self.Serializer(data=data)
assert not serializer.is_valid()
- assert serializer.errors == {'non_field_errors': ['No data provided']}
+ assert serializer.errors == {"non_field_errors": ["No data provided"]}
- @unittest.skipUnless(ChainMap, 'requires python 3.3')
+ @unittest.skipUnless(ChainMap, "requires python 3.3")
def test_serialize_chainmap(self):
- data = ChainMap({'char': 'abc'}, {'integer': 123})
+ data = ChainMap({"char": "abc"}, {"integer": 123})
serializer = self.Serializer(data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {'char': 'abc', 'integer': 123}
+ assert serializer.validated_data == {"char": "abc", "integer": 123}
assert serializer.errors == {}
def test_serialize_custom_mapping(self):
class SinglePurposeMapping(Mapping):
def __getitem__(self, key):
- return 'abc' if key == 'char' else 123
+ return "abc" if key == "char" else 123
def __iter__(self):
- yield 'char'
- yield 'integer'
+ yield "char"
+ yield "integer"
def __len__(self):
return 2
serializer = self.Serializer(data=SinglePurposeMapping())
assert serializer.is_valid()
- assert serializer.validated_data == {'char': 'abc', 'integer': 123}
+ assert serializer.validated_data == {"char": "abc", "integer": 123}
assert serializer.errors == {}
def test_custom_to_internal_value(self):
@@ -160,6 +166,7 @@ class TestSerializer:
to_internal_value() is expected to return a dict, but subclasses may
return application specific type.
"""
+
class Point(object):
def __init__(self, srid, x, y):
self.srid = srid
@@ -167,14 +174,16 @@ class TestSerializer:
# Declares a serializer that converts data into an object
class NestedPointSerializer(serializers.Serializer):
- longitude = serializers.FloatField(source='x')
- latitude = serializers.FloatField(source='y')
+ longitude = serializers.FloatField(source="x")
+ latitude = serializers.FloatField(source="y")
def to_internal_value(self, data):
kwargs = super(NestedPointSerializer, self).to_internal_value(data)
return Point(srid=4326, **kwargs)
- serializer = NestedPointSerializer(data={'longitude': 6.958307, 'latitude': 50.941357})
+ serializer = NestedPointSerializer(
+ data={"longitude": 6.958307, "latitude": 50.941357}
+ )
assert serializer.is_valid()
assert isinstance(serializer.validated_data, Point)
assert serializer.validated_data.srid == 4326
@@ -186,9 +195,10 @@ class TestSerializer:
"""
Ensure `validators` parameter is compatible with reasonable iterables.
"""
- data = {'char': 'abc', 'integer': 123}
+ data = {"char": "abc", "integer": 123}
for validators in ([], (), set()):
+
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField(validators=validators)
integer = serializers.IntegerField()
@@ -199,9 +209,14 @@ class TestSerializer:
assert serializer.errors == {}
def raise_exception(value):
- raise exceptions.ValidationError('Raised error')
+ raise exceptions.ValidationError("Raised error")
+
+ for validators in (
+ [raise_exception],
+ (raise_exception,),
+ set([raise_exception]),
+ ):
- for validators in ([raise_exception], (raise_exception,), set([raise_exception])):
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField(validators=validators)
integer = serializers.IntegerField()
@@ -210,9 +225,9 @@ class TestSerializer:
assert not serializer.is_valid()
assert serializer.data == data
assert serializer.validated_data == {}
- assert serializer.errors == {'char': [
- exceptions.ErrorDetail(string='Raised error', code='invalid')
- ]}
+ assert serializer.errors == {
+ "char": [exceptions.ErrorDetail(string="Raised error", code="invalid")]
+ }
class TestValidateMethod:
@@ -222,11 +237,11 @@ class TestValidateMethod:
integer = serializers.IntegerField()
def validate(self, attrs):
- raise serializers.ValidationError('Non field error')
+ raise serializers.ValidationError("Non field error")
- serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
+ serializer = ExampleSerializer(data={"char": "abc", "integer": 123})
assert not serializer.is_valid()
- assert serializer.errors == {'non_field_errors': ['Non field error']}
+ assert serializer.errors == {"non_field_errors": ["Non field error"]}
def test_field_error_validate_method(self):
class ExampleSerializer(serializers.Serializer):
@@ -234,29 +249,22 @@ class TestValidateMethod:
integer = serializers.IntegerField()
def validate(self, attrs):
- raise serializers.ValidationError({'char': 'Field error'})
+ raise serializers.ValidationError({"char": "Field error"})
- serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
+ serializer = ExampleSerializer(data={"char": "abc", "integer": 123})
assert not serializer.is_valid()
- assert serializer.errors == {'char': ['Field error']}
+ assert serializer.errors == {"char": ["Field error"]}
class TestBaseSerializer:
def setup(self):
class ExampleSerializer(serializers.BaseSerializer):
def to_representation(self, obj):
- return {
- 'id': obj['id'],
- 'email': obj['name'] + '@' + obj['domain']
- }
+ return {"id": obj["id"], "email": obj["name"] + "@" + obj["domain"]}
def to_internal_value(self, data):
- name, domain = str(data['email']).split('@')
- return {
- 'id': int(data['id']),
- 'name': name,
- 'domain': domain,
- }
+ name, domain = str(data["email"]).split("@")
+ return {"id": int(data["id"]), "name": name, "domain": domain}
self.Serializer = ExampleSerializer
@@ -272,56 +280,56 @@ class TestBaseSerializer:
serializer.create(None)
def test_access_to_data_attribute_before_validation_raises_error(self):
- serializer = serializers.BaseSerializer(data={'foo': 'bar'})
+ serializer = serializers.BaseSerializer(data={"foo": "bar"})
with pytest.raises(AssertionError):
serializer.data
def test_access_to_errors_attribute_before_validation_raises_error(self):
- serializer = serializers.BaseSerializer(data={'foo': 'bar'})
+ serializer = serializers.BaseSerializer(data={"foo": "bar"})
with pytest.raises(AssertionError):
serializer.errors
def test_access_to_validated_data_attribute_before_validation_raises_error(self):
- serializer = serializers.BaseSerializer(data={'foo': 'bar'})
+ serializer = serializers.BaseSerializer(data={"foo": "bar"})
with pytest.raises(AssertionError):
serializer.validated_data
def test_serialize_instance(self):
- instance = {'id': 1, 'name': 'tom', 'domain': 'example.com'}
+ instance = {"id": 1, "name": "tom", "domain": "example.com"}
serializer = self.Serializer(instance)
- assert serializer.data == {'id': 1, 'email': 'tom@example.com'}
+ assert serializer.data == {"id": 1, "email": "tom@example.com"}
def test_serialize_list(self):
instances = [
- {'id': 1, 'name': 'tom', 'domain': 'example.com'},
- {'id': 2, 'name': 'ann', 'domain': 'example.com'},
+ {"id": 1, "name": "tom", "domain": "example.com"},
+ {"id": 2, "name": "ann", "domain": "example.com"},
]
serializer = self.Serializer(instances, many=True)
assert serializer.data == [
- {'id': 1, 'email': 'tom@example.com'},
- {'id': 2, 'email': 'ann@example.com'}
+ {"id": 1, "email": "tom@example.com"},
+ {"id": 2, "email": "ann@example.com"},
]
def test_validate_data(self):
- data = {'id': 1, 'email': 'tom@example.com'}
+ data = {"id": 1, "email": "tom@example.com"}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {
- 'id': 1,
- 'name': 'tom',
- 'domain': 'example.com'
+ "id": 1,
+ "name": "tom",
+ "domain": "example.com",
}
def test_validate_list(self):
data = [
- {'id': 1, 'email': 'tom@example.com'},
- {'id': 2, 'email': 'ann@example.com'},
+ {"id": 1, "email": "tom@example.com"},
+ {"id": 2, "email": "ann@example.com"},
]
serializer = self.Serializer(data=data, many=True)
assert serializer.is_valid()
assert serializer.validated_data == [
- {'id': 1, 'name': 'tom', 'domain': 'example.com'},
- {'id': 2, 'name': 'ann', 'domain': 'example.com'}
+ {"id": 1, "name": "tom", "domain": "example.com"},
+ {"id": 2, "name": "ann", "domain": "example.com"},
]
@@ -333,10 +341,8 @@ class TestStarredSource:
nested_field = NestedField(source='*')
"""
- data = {
- 'nested1': {'a': 1, 'b': 2},
- 'nested2': {'c': 3, 'd': 4}
- }
+
+ data = {"nested1": {"a": 1, "b": 2}, "nested2": {"c": 3, "d": 4}}
def setup(self):
class NestedSerializer1(serializers.Serializer):
@@ -348,8 +354,8 @@ class TestStarredSource:
d = serializers.IntegerField()
class TestSerializer(serializers.Serializer):
- nested1 = NestedSerializer1(source='*')
- nested2 = NestedSerializer2(source='*')
+ nested1 = NestedSerializer1(source="*")
+ nested2 = NestedSerializer2(source="*")
self.Serializer = TestSerializer
@@ -359,18 +365,13 @@ class TestStarredSource:
"""
serializer = self.Serializer(data=self.data)
assert serializer.is_valid()
- assert serializer.validated_data == {
- 'a': 1,
- 'b': 2,
- 'c': 3,
- 'd': 4
- }
+ assert serializer.validated_data == {"a": 1, "b": 2, "c": 3, "d": 4}
def test_nested_serialize(self):
"""
An object can be serialized into a nested representation.
"""
- instance = {'a': 1, 'b': 2, 'c': 3, 'd': 4}
+ instance = {"a": 1, "b": 2, "c": 3, "d": 4}
serializer = self.Serializer(instance)
assert serializer.data == self.data
@@ -403,7 +404,7 @@ class TestUnicodeRepr:
class ExampleObject:
def __init__(self):
- self.example = '한êµ'
+ self.example = "한êµ"
def __repr__(self):
return unicode_repr(self.example)
@@ -418,18 +419,20 @@ class TestNotRequiredOutput:
"""
'required=False' should allow a dictionary key to be missing in output.
"""
+
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(required=False)
included = serializers.CharField()
- serializer = ExampleSerializer(data={'included': 'abc'})
+ serializer = ExampleSerializer(data={"included": "abc"})
serializer.is_valid()
- assert serializer.data == {'included': 'abc'}
+ assert serializer.data == {"included": "abc"}
def test_not_required_output_for_object(self):
"""
'required=False' should allow an object attribute to be missing in output.
"""
+
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(required=False)
included = serializers.CharField()
@@ -437,115 +440,146 @@ class TestNotRequiredOutput:
def create(self, validated_data):
return MockObject(**validated_data)
- serializer = ExampleSerializer(data={'included': 'abc'})
+ serializer = ExampleSerializer(data={"included": "abc"})
serializer.is_valid()
serializer.save()
- assert serializer.data == {'included': 'abc'}
+ assert serializer.data == {"included": "abc"}
class TestDefaultOutput:
def setup(self):
class ExampleSerializer(serializers.Serializer):
- has_default = serializers.CharField(default='x')
- has_default_callable = serializers.CharField(default=lambda: 'y')
+ has_default = serializers.CharField(default="x")
+ has_default_callable = serializers.CharField(default=lambda: "y")
no_default = serializers.CharField()
+
self.Serializer = ExampleSerializer
def test_default_used_for_dict(self):
"""
'default="something"' should be used if dictionary key is missing from input.
"""
- serializer = self.Serializer({'no_default': 'abc'})
- assert serializer.data == {'has_default': 'x', 'has_default_callable': 'y', 'no_default': 'abc'}
+ serializer = self.Serializer({"no_default": "abc"})
+ assert serializer.data == {
+ "has_default": "x",
+ "has_default_callable": "y",
+ "no_default": "abc",
+ }
def test_default_used_for_object(self):
"""
'default="something"' should be used if object attribute is missing from input.
"""
- instance = MockObject(no_default='abc')
+ instance = MockObject(no_default="abc")
serializer = self.Serializer(instance)
- assert serializer.data == {'has_default': 'x', 'has_default_callable': 'y', 'no_default': 'abc'}
+ assert serializer.data == {
+ "has_default": "x",
+ "has_default_callable": "y",
+ "no_default": "abc",
+ }
def test_default_not_used_when_in_dict(self):
"""
'default="something"' should not be used if dictionary key is present in input.
"""
- serializer = self.Serializer({'has_default': 'def', 'has_default_callable': 'ghi', 'no_default': 'abc'})
- assert serializer.data == {'has_default': 'def', 'has_default_callable': 'ghi', 'no_default': 'abc'}
+ serializer = self.Serializer(
+ {"has_default": "def", "has_default_callable": "ghi", "no_default": "abc"}
+ )
+ assert serializer.data == {
+ "has_default": "def",
+ "has_default_callable": "ghi",
+ "no_default": "abc",
+ }
def test_default_not_used_when_in_object(self):
"""
'default="something"' should not be used if object attribute is present in input.
"""
- instance = MockObject(has_default='def', has_default_callable='ghi', no_default='abc')
+ instance = MockObject(
+ has_default="def", has_default_callable="ghi", no_default="abc"
+ )
serializer = self.Serializer(instance)
- assert serializer.data == {'has_default': 'def', 'has_default_callable': 'ghi', 'no_default': 'abc'}
+ assert serializer.data == {
+ "has_default": "def",
+ "has_default_callable": "ghi",
+ "no_default": "abc",
+ }
def test_default_for_dotted_source(self):
"""
'default="something"' should be used when a traversed attribute is missing from input.
"""
+
class Serializer(serializers.Serializer):
- traversed = serializers.CharField(default='x', source='traversed.attr')
+ traversed = serializers.CharField(default="x", source="traversed.attr")
- assert Serializer({}).data == {'traversed': 'x'}
- assert Serializer({'traversed': {}}).data == {'traversed': 'x'}
- assert Serializer({'traversed': None}).data == {'traversed': 'x'}
+ assert Serializer({}).data == {"traversed": "x"}
+ assert Serializer({"traversed": {}}).data == {"traversed": "x"}
+ assert Serializer({"traversed": None}).data == {"traversed": "x"}
- assert Serializer({'traversed': {'attr': 'abc'}}).data == {'traversed': 'abc'}
+ assert Serializer({"traversed": {"attr": "abc"}}).data == {"traversed": "abc"}
def test_default_for_multiple_dotted_source(self):
class Serializer(serializers.Serializer):
- c = serializers.CharField(default='x', source='a.b.c')
+ c = serializers.CharField(default="x", source="a.b.c")
- assert Serializer({}).data == {'c': 'x'}
- assert Serializer({'a': {}}).data == {'c': 'x'}
- assert Serializer({'a': None}).data == {'c': 'x'}
- assert Serializer({'a': {'b': {}}}).data == {'c': 'x'}
- assert Serializer({'a': {'b': None}}).data == {'c': 'x'}
+ assert Serializer({}).data == {"c": "x"}
+ assert Serializer({"a": {}}).data == {"c": "x"}
+ assert Serializer({"a": None}).data == {"c": "x"}
+ assert Serializer({"a": {"b": {}}}).data == {"c": "x"}
+ assert Serializer({"a": {"b": None}}).data == {"c": "x"}
- assert Serializer({'a': {'b': {'c': 'abc'}}}).data == {'c': 'abc'}
+ assert Serializer({"a": {"b": {"c": "abc"}}}).data == {"c": "abc"}
# Same test using model objects to exercise both paths in
# rest_framework.fields.get_attribute() (#5880)
class ModelSerializer(serializers.Serializer):
- target = serializers.CharField(default='x', source='target.target.name')
+ target = serializers.CharField(default="x", source="target.target.name")
a = NestedForeignKeySource(name="Root Object", target=None)
- assert ModelSerializer(a).data == {'target': 'x'}
+ assert ModelSerializer(a).data == {"target": "x"}
b = NullableForeignKeySource(name="Intermediary Object", target=None)
a.target = b
- assert ModelSerializer(a).data == {'target': 'x'}
+ assert ModelSerializer(a).data == {"target": "x"}
c = ForeignKeyTarget(name="Target Object")
b.target = c
- assert ModelSerializer(a).data == {'target': 'Target Object'}
+ assert ModelSerializer(a).data == {"target": "Target Object"}
def test_default_for_nested_serializer(self):
class NestedSerializer(serializers.Serializer):
- a = serializers.CharField(default='1')
- c = serializers.CharField(default='2', source='b.c')
+ a = serializers.CharField(default="1")
+ c = serializers.CharField(default="2", source="b.c")
class Serializer(serializers.Serializer):
nested = NestedSerializer()
- assert Serializer({'nested': None}).data == {'nested': None}
- assert Serializer({'nested': {}}).data == {'nested': {'a': '1', 'c': '2'}}
- assert Serializer({'nested': {'a': '3', 'b': {}}}).data == {'nested': {'a': '3', 'c': '2'}}
- assert Serializer({'nested': {'a': '3', 'b': {'c': '4'}}}).data == {'nested': {'a': '3', 'c': '4'}}
+ assert Serializer({"nested": None}).data == {"nested": None}
+ assert Serializer({"nested": {}}).data == {"nested": {"a": "1", "c": "2"}}
+ assert Serializer({"nested": {"a": "3", "b": {}}}).data == {
+ "nested": {"a": "3", "c": "2"}
+ }
+ assert Serializer({"nested": {"a": "3", "b": {"c": "4"}}}).data == {
+ "nested": {"a": "3", "c": "4"}
+ }
def test_default_for_allow_null(self):
"""
Without an explicit default, allow_null implies default=None when serializing. #5518 #5708
"""
+
class Serializer(serializers.Serializer):
foo = serializers.CharField()
- bar = serializers.CharField(source='foo.bar', allow_null=True)
+ bar = serializers.CharField(source="foo.bar", allow_null=True)
optional = serializers.CharField(required=False, allow_null=True)
# allow_null=True should imply default=None when serialising:
- assert Serializer({'foo': None}).data == {'foo': None, 'bar': None, 'optional': None, }
+ assert Serializer({"foo": None}).data == {
+ "foo": None,
+ "bar": None,
+ "optional": None,
+ }
class TestCacheSerializerData:
@@ -554,54 +588,57 @@ class TestCacheSerializerData:
Caching serializer data with pickle will drop the serializer info,
but does preserve the data itself.
"""
+
class ExampleSerializer(serializers.Serializer):
field1 = serializers.CharField()
field2 = serializers.CharField()
- serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'})
+ serializer = ExampleSerializer({"field1": "a", "field2": "b"})
pickled = pickle.dumps(serializer.data)
data = pickle.loads(pickled)
- assert data == {'field1': 'a', 'field2': 'b'}
+ assert data == {"field1": "a", "field2": "b"}
class TestDefaultInclusions:
def setup(self):
class ExampleSerializer(serializers.Serializer):
- char = serializers.CharField(default='abc')
+ char = serializers.CharField(default="abc")
integer = serializers.IntegerField()
+
self.Serializer = ExampleSerializer
def test_default_should_included_on_create(self):
- serializer = self.Serializer(data={'integer': 456})
+ serializer = self.Serializer(data={"integer": 456})
assert serializer.is_valid()
- assert serializer.validated_data == {'char': 'abc', 'integer': 456}
+ assert serializer.validated_data == {"char": "abc", "integer": 456}
assert serializer.errors == {}
def test_default_should_be_included_on_update(self):
- instance = MockObject(char='def', integer=123)
- serializer = self.Serializer(instance, data={'integer': 456})
+ instance = MockObject(char="def", integer=123)
+ serializer = self.Serializer(instance, data={"integer": 456})
assert serializer.is_valid()
- assert serializer.validated_data == {'char': 'abc', 'integer': 456}
+ assert serializer.validated_data == {"char": "abc", "integer": 456}
assert serializer.errors == {}
def test_default_should_not_be_included_on_partial_update(self):
- instance = MockObject(char='def', integer=123)
- serializer = self.Serializer(instance, data={'integer': 456}, partial=True)
+ instance = MockObject(char="def", integer=123)
+ serializer = self.Serializer(instance, data={"integer": 456}, partial=True)
assert serializer.is_valid()
- assert serializer.validated_data == {'integer': 456}
+ assert serializer.validated_data == {"integer": 456}
assert serializer.errors == {}
class TestSerializerValidationWithCompiledRegexField:
def setup(self):
class ExampleSerializer(serializers.Serializer):
- name = serializers.RegexField(re.compile(r'\d'), required=True)
+ name = serializers.RegexField(re.compile(r"\d"), required=True)
+
self.Serializer = ExampleSerializer
def test_validation_success(self):
- serializer = self.Serializer(data={'name': '2'})
+ serializer = self.Serializer(data={"name": "2"})
assert serializer.is_valid()
- assert serializer.validated_data == {'name': '2'}
+ assert serializer.validated_data == {"name": "2"}
assert serializer.errors == {}
@@ -616,9 +653,9 @@ class Test2555Regression:
class ParentSerializer(serializers.Serializer):
nested = NestedSerializer()
- serializer = ParentSerializer(data={}, context={'foo': 'bar'})
- assert serializer.context == {'foo': 'bar'}
- assert serializer.fields['nested'].context == {'foo': 'bar'}
+ serializer = ParentSerializer(data={}, context={"foo": "bar"})
+ assert serializer.context == {"foo": "bar"}
+ assert serializer.fields["nested"].context == {"foo": "bar"}
class Test4606Regression:
@@ -626,6 +663,7 @@ class Test4606Regression:
class ExampleSerializer(serializers.Serializer):
name = serializers.CharField(required=True)
choices = serializers.CharField(required=True)
+
self.Serializer = ExampleSerializer
def test_4606_regression(self):
@@ -660,7 +698,7 @@ class TestDeclaredFieldInheritance:
class Parent(serializers.ModelSerializer):
class Meta:
model = MyModel
- fields = ['f1', 'f2']
+ fields = ["f1", "f2"]
class Child(Parent):
f1 = None
diff --git a/tests/test_serializer_bulk_update.py b/tests/test_serializer_bulk_update.py
index d9e5d7978..a4d6bf68c 100644
--- a/tests/test_serializer_bulk_update.py
+++ b/tests/test_serializer_bulk_update.py
@@ -29,18 +29,16 @@ class BulkCreateSerializerTests(TestCase):
data = [
{
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 2,
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
+ "id": 0,
+ "title": "The electric kool-aid acid test",
+ "author": "Tom Wolfe",
+ },
+ {"id": 1, "title": "If this is a man", "author": "Primo Levi"},
+ {
+ "id": 2,
+ "title": "The wind-up bird chronicle",
+ "author": "Haruki Murakami",
+ },
]
serializer = self.BookSerializer(data=data, many=True)
@@ -55,24 +53,18 @@ class BulkCreateSerializerTests(TestCase):
data = [
{
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
- }, {
- 'id': 1,
- 'title': 'If this is a man',
- 'author': 'Primo Levi'
- }, {
- 'id': 'foo',
- 'title': 'The wind-up bird chronicle',
- 'author': 'Haruki Murakami'
- }
- ]
- expected_errors = [
- {},
- {},
- {'id': ['A valid integer is required.']}
+ "id": 0,
+ "title": "The electric kool-aid acid test",
+ "author": "Tom Wolfe",
+ },
+ {"id": 1, "title": "If this is a man", "author": "Primo Levi"},
+ {
+ "id": "foo",
+ "title": "The wind-up bird chronicle",
+ "author": "Haruki Murakami",
+ },
]
+ expected_errors = [{}, {}, {"id": ["A valid integer is required."]}]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
@@ -83,16 +75,16 @@ class BulkCreateSerializerTests(TestCase):
"""
Data containing list of incorrect data type should return errors.
"""
- data = ['foo', 'bar', 'baz']
+ data = ["foo", "bar", "baz"]
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
text_type_string = six.text_type.__name__
- message = 'Invalid data. Expected a dictionary, but got %s.' % text_type_string
+ message = "Invalid data. Expected a dictionary, but got %s." % text_type_string
expected_errors = [
- {'non_field_errors': [message]},
- {'non_field_errors': [message]},
- {'non_field_errors': [message]}
+ {"non_field_errors": [message]},
+ {"non_field_errors": [message]},
+ {"non_field_errors": [message]},
]
assert serializer.errors == expected_errors
@@ -105,7 +97,9 @@ class BulkCreateSerializerTests(TestCase):
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
- expected_errors = {'non_field_errors': ['Expected a list of items but got type "int".']}
+ expected_errors = {
+ "non_field_errors": ['Expected a list of items but got type "int".']
+ }
assert serializer.errors == expected_errors
@@ -115,13 +109,15 @@ class BulkCreateSerializerTests(TestCase):
should return errors.
"""
data = {
- 'id': 0,
- 'title': 'The electric kool-aid acid test',
- 'author': 'Tom Wolfe'
+ "id": 0,
+ "title": "The electric kool-aid acid test",
+ "author": "Tom Wolfe",
}
serializer = self.BookSerializer(data=data, many=True)
assert serializer.is_valid() is False
- expected_errors = {'non_field_errors': ['Expected a list of items but got type "dict".']}
+ expected_errors = {
+ "non_field_errors": ['Expected a list of items but got type "dict".']
+ }
assert serializer.errors == expected_errors
diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py
index 12ed78b84..5bb818a93 100644
--- a/tests/test_serializer_lists.py
+++ b/tests/test_serializer_lists.py
@@ -8,6 +8,7 @@ class BasicObject:
"""
A mock object for testing serializer save behavior.
"""
+
def __init__(self, **kwargs):
self._data = kwargs
for key, value in kwargs.items():
@@ -31,6 +32,7 @@ class TestListSerializer:
def setup(self):
class IntegerListSerializer(serializers.ListSerializer):
child = serializers.IntegerField()
+
self.Serializer = IntegerListSerializer
def test_validate(self):
@@ -79,11 +81,11 @@ class TestListSerializerContainingNestedSerializer:
"""
input_data = [
{"integer": "123", "boolean": "true"},
- {"integer": "456", "boolean": "false"}
+ {"integer": "456", "boolean": "false"},
]
expected_output = [
{"integer": 123, "boolean": True},
- {"integer": 456, "boolean": False}
+ {"integer": 456, "boolean": False},
]
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
@@ -95,7 +97,7 @@ class TestListSerializerContainingNestedSerializer:
"""
input_data = [
{"integer": "123", "boolean": "true"},
- {"integer": "456", "boolean": "false"}
+ {"integer": "456", "boolean": "false"},
]
expected_output = [
BasicObject(integer=123, boolean=True),
@@ -111,11 +113,11 @@ class TestListSerializerContainingNestedSerializer:
"""
input_objects = [
BasicObject(integer=123, boolean=True),
- BasicObject(integer=456, boolean=False)
+ BasicObject(integer=456, boolean=False),
]
expected_output = [
{"integer": 123, "boolean": True},
- {"integer": 456, "boolean": False}
+ {"integer": 456, "boolean": False},
]
serializer = self.Serializer(input_objects)
assert serializer.data == expected_output
@@ -125,15 +127,17 @@ class TestListSerializerContainingNestedSerializer:
HTML input should be able to mock list structures using [x]
style prefixes.
"""
- input_data = MultiValueDict({
- "[0]integer": ["123"],
- "[0]boolean": ["true"],
- "[1]integer": ["456"],
- "[1]boolean": ["false"]
- })
+ input_data = MultiValueDict(
+ {
+ "[0]integer": ["123"],
+ "[0]boolean": ["true"],
+ "[1]integer": ["456"],
+ "[1]boolean": ["false"],
+ }
+ )
expected_output = [
{"integer": 123, "boolean": True},
- {"integer": 456, "boolean": False}
+ {"integer": 456, "boolean": False},
]
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
@@ -159,14 +163,8 @@ class TestNestedListSerializer:
"""
Validating a list of items should return a list of validated items.
"""
- input_data = {
- "integers": ["123", "456"],
- "booleans": ["true", "false"]
- }
- expected_output = {
- "integers": [123, 456],
- "booleans": [True, False]
- }
+ input_data = {"integers": ["123", "456"], "booleans": ["true", "false"]}
+ expected_output = {"integers": [123, 456], "booleans": [True, False]}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert serializer.validated_data == expected_output
@@ -176,14 +174,8 @@ class TestNestedListSerializer:
Creation with a list of items return an object with an attribute that
is a list of items.
"""
- input_data = {
- "integers": ["123", "456"],
- "booleans": ["true", "false"]
- }
- expected_output = BasicObject(
- integers=[123, 456],
- booleans=[True, False]
- )
+ input_data = {"integers": ["123", "456"], "booleans": ["true", "false"]}
+ expected_output = BasicObject(integers=[123, 456], booleans=[True, False])
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert serializer.save() == expected_output
@@ -192,14 +184,8 @@ class TestNestedListSerializer:
"""
Serialization of a list of items should return a list of items.
"""
- input_object = BasicObject(
- integers=[123, 456],
- booleans=[True, False]
- )
- expected_output = {
- "integers": [123, 456],
- "booleans": [True, False]
- }
+ input_object = BasicObject(integers=[123, 456], booleans=[True, False])
+ expected_output = {"integers": [123, 456], "booleans": [True, False]}
serializer = self.Serializer(input_object)
assert serializer.data == expected_output
@@ -208,16 +194,15 @@ class TestNestedListSerializer:
HTML input should be able to mock list structures using [x]
style prefixes.
"""
- input_data = MultiValueDict({
- "integers[0]": ["123"],
- "integers[1]": ["456"],
- "booleans[0]": ["true"],
- "booleans[1]": ["false"]
- })
- expected_output = {
- "integers": [123, 456],
- "booleans": [True, False]
- }
+ input_data = MultiValueDict(
+ {
+ "integers[0]": ["123"],
+ "integers[1]": ["456"],
+ "booleans[0]": ["true"],
+ "booleans[1]": ["false"],
+ }
+ )
+ expected_output = {"integers": [123, 456], "booleans": [True, False]}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert serializer.validated_data == expected_output
@@ -227,26 +212,22 @@ class TestNestedListOfListsSerializer:
def setup(self):
class TestSerializer(serializers.Serializer):
integers = serializers.ListSerializer(
- child=serializers.ListSerializer(
- child=serializers.IntegerField()
- )
+ child=serializers.ListSerializer(child=serializers.IntegerField())
)
booleans = serializers.ListSerializer(
- child=serializers.ListSerializer(
- child=serializers.BooleanField()
- )
+ child=serializers.ListSerializer(child=serializers.BooleanField())
)
self.Serializer = TestSerializer
def test_validate(self):
input_data = {
- 'integers': [['123', '456'], ['789', '0']],
- 'booleans': [['true', 'true'], ['false', 'true']]
+ "integers": [["123", "456"], ["789", "0"]],
+ "booleans": [["true", "true"], ["false", "true"]],
}
expected_output = {
"integers": [[123, 456], [789, 0]],
- "booleans": [[True, True], [False, True]]
+ "booleans": [[True, True], [False, True]],
}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
@@ -257,19 +238,21 @@ class TestNestedListOfListsSerializer:
HTML input should be able to mock lists of lists using [x][y]
style prefixes.
"""
- input_data = MultiValueDict({
- "integers[0][0]": ["123"],
- "integers[0][1]": ["456"],
- "integers[1][0]": ["789"],
- "integers[1][1]": ["000"],
- "booleans[0][0]": ["true"],
- "booleans[0][1]": ["true"],
- "booleans[1][0]": ["false"],
- "booleans[1][1]": ["true"]
- })
+ input_data = MultiValueDict(
+ {
+ "integers[0][0]": ["123"],
+ "integers[0][1]": ["456"],
+ "integers[1][0]": ["789"],
+ "integers[1][1]": ["000"],
+ "booleans[0][0]": ["true"],
+ "booleans[0][1]": ["true"],
+ "booleans[1][0]": ["false"],
+ "booleans[1][1]": ["true"],
+ }
+ )
expected_output = {
"integers": [[123, 456], [789, 0]],
- "booleans": [[True, True], [False, True]]
+ "booleans": [[True, True], [False, True]],
}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
@@ -278,10 +261,11 @@ class TestNestedListOfListsSerializer:
class TestListSerializerClass:
"""Tests for a custom list_serializer_class."""
+
def test_list_serializer_class_validate(self):
class CustomListSerializer(serializers.ListSerializer):
def validate(self, attrs):
- raise serializers.ValidationError('Non field error')
+ raise serializers.ValidationError("Non field error")
class TestSerializer(serializers.Serializer):
class Meta:
@@ -289,7 +273,7 @@ class TestListSerializerClass:
serializer = TestSerializer(data=[], many=True)
assert not serializer.is_valid()
- assert serializer.errors == {'non_field_errors': ['Non field error']}
+ assert serializer.errors == {"non_field_errors": ["Non field error"]}
class TestSerializerPartialUsage:
@@ -300,9 +284,11 @@ class TestSerializerPartialUsage:
Regression test for Github issue #2761.
"""
+
def test_partial_listfield(self):
class ListSerializer(serializers.Serializer):
listdata = serializers.ListField()
+
serializer = ListSerializer(data=MultiValueDict(), partial=True)
result = serializer.to_internal_value(data={})
assert "listdata" not in result
@@ -313,6 +299,7 @@ class TestSerializerPartialUsage:
def test_partial_multiplechoice(self):
class MultipleChoiceSerializer(serializers.Serializer):
multiplechoice = serializers.MultipleChoiceField(choices=[1, 2, 3])
+
serializer = MultipleChoiceSerializer(data=MultiValueDict(), partial=True)
result = serializer.to_internal_value(data={})
assert "multiplechoice" not in result
@@ -326,8 +313,8 @@ class TestSerializerPartialUsage:
store_field = serializers.IntegerField()
instance = [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
]
serializer = ListSerializer(instance, data=[], partial=True, many=True)
@@ -341,17 +328,16 @@ class TestSerializerPartialUsage:
store_field = serializers.IntegerField()
instance = [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
]
- input_data = [{'update_field': 31}, {'update_field': 41}]
+ input_data = [{"update_field": 31}, {"update_field": 41}]
updated_data_list = [
- {'update_field': 31, 'store_field': 12},
- {'update_field': 41, 'store_field': 22},
+ {"update_field": 31, "store_field": 12},
+ {"update_field": 41, "store_field": 22},
]
- serializer = ListSerializer(
- instance, data=input_data, partial=True, many=True)
+ serializer = ListSerializer(instance, data=input_data, partial=True, many=True)
assert serializer.is_valid()
for index, data in enumerate(serializer.validated_data):
@@ -366,16 +352,17 @@ class TestSerializerPartialUsage:
store_field = serializers.IntegerField()
instance = [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
]
serializer = ListSerializer(
- instance, data=[], allow_empty=False, partial=True, many=True)
+ instance, data=[], allow_empty=False, partial=True, many=True
+ )
assert not serializer.is_valid()
assert serializer.validated_data == []
assert len(serializer.errors) == 1
- assert serializer.errors['non_field_errors'][0] == 'This list may not be empty.'
+ assert serializer.errors["non_field_errors"][0] == "This list may not be empty."
def test_update_allow_empty_false(self):
class ListSerializer(serializers.Serializer):
@@ -383,17 +370,18 @@ class TestSerializerPartialUsage:
store_field = serializers.IntegerField()
instance = [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
]
- input_data = [{'update_field': 31}, {'update_field': 41}]
+ input_data = [{"update_field": 31}, {"update_field": 41}]
updated_data_list = [
- {'update_field': 31, 'store_field': 12},
- {'update_field': 41, 'store_field': 22},
+ {"update_field": 31, "store_field": 12},
+ {"update_field": 41, "store_field": 22},
]
serializer = ListSerializer(
- instance, data=input_data, allow_empty=False, partial=True, many=True)
+ instance, data=input_data, allow_empty=False, partial=True, many=True
+ )
assert serializer.is_valid()
for index, data in enumerate(serializer.validated_data):
@@ -412,11 +400,11 @@ class TestSerializerPartialUsage:
list_field = ListSerializer(many=True)
instance = {
- 'extra_field': 1,
- 'list_field': [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
- ]
+ "extra_field": 1,
+ "list_field": [
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
+ ],
}
serializer = Serializer(instance, data={}, partial=True)
@@ -434,25 +422,20 @@ class TestSerializerPartialUsage:
list_field = ListSerializer(many=True)
instance = {
- 'extra_field': 1,
- 'list_field': [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
- ]
- }
- input_data_1 = {'extra_field': 2}
- input_data_2 = {
- 'list_field': [
- {'update_field': 31},
- {'update_field': 41},
- ]
+ "extra_field": 1,
+ "list_field": [
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
+ ],
}
+ input_data_1 = {"extra_field": 2}
+ input_data_2 = {"list_field": [{"update_field": 31}, {"update_field": 41}]}
# data_1
serializer = Serializer(instance, data=input_data_1, partial=True)
assert serializer.is_valid()
assert len(serializer.validated_data) == 1
- assert serializer.validated_data['extra_field'] == 2
+ assert serializer.validated_data["extra_field"] == 2
assert serializer.errors == {}
# data_2
@@ -460,10 +443,10 @@ class TestSerializerPartialUsage:
assert serializer.is_valid()
updated_data_list = [
- {'update_field': 31, 'store_field': 12},
- {'update_field': 41, 'store_field': 22},
+ {"update_field": 31, "store_field": 12},
+ {"update_field": 41, "store_field": 22},
]
- for index, data in enumerate(serializer.validated_data['list_field']):
+ for index, data in enumerate(serializer.validated_data["list_field"]):
for key, value in data.items():
assert value == updated_data_list[index][key]
@@ -479,11 +462,11 @@ class TestSerializerPartialUsage:
list_field = ListSerializer(many=True, allow_empty=False)
instance = {
- 'extra_field': 1,
- 'list_field': [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
- ]
+ "extra_field": 1,
+ "list_field": [
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
+ ],
}
serializer = Serializer(instance, data={}, partial=True)
@@ -501,22 +484,17 @@ class TestSerializerPartialUsage:
list_field = ListSerializer(many=True, allow_empty=False)
instance = {
- 'extra_field': 1,
- 'list_field': [
- {'update_field': 11, 'store_field': 12},
- {'update_field': 21, 'store_field': 22},
- ]
- }
- input_data_1 = {'extra_field': 2}
- input_data_2 = {
- 'list_field': [
- {'update_field': 31},
- {'update_field': 41},
- ]
+ "extra_field": 1,
+ "list_field": [
+ {"update_field": 11, "store_field": 12},
+ {"update_field": 21, "store_field": 22},
+ ],
}
+ input_data_1 = {"extra_field": 2}
+ input_data_2 = {"list_field": [{"update_field": 31}, {"update_field": 41}]}
updated_data_list = [
- {'update_field': 31, 'store_field': 12},
- {'update_field': 41, 'store_field': 22},
+ {"update_field": 31, "store_field": 12},
+ {"update_field": 41, "store_field": 22},
]
# data_1
@@ -528,7 +506,7 @@ class TestSerializerPartialUsage:
serializer = Serializer(instance, data=input_data_2, partial=True)
assert serializer.is_valid()
- for index, data in enumerate(serializer.validated_data['list_field']):
+ for index, data in enumerate(serializer.validated_data["list_field"]):
for key, value in data.items():
assert value == updated_data_list[index][key]
@@ -557,7 +535,7 @@ class TestEmptyListSerializer:
def test_nested_serializer_with_list_multipart(self):
# pass an "empty" QueryDict to the serializer (should be the same as an empty array)
- input_data = QueryDict('')
+ input_data = QueryDict("")
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
diff --git a/tests/test_serializer_nested.py b/tests/test_serializer_nested.py
index 1cd0caf85..35e88f5f9 100644
--- a/tests/test_serializer_nested.py
+++ b/tests/test_serializer_nested.py
@@ -15,29 +15,14 @@ class TestNestedSerializer:
self.Serializer = TestSerializer
def test_nested_validate(self):
- input_data = {
- 'nested': {
- 'one': '1',
- 'two': '2',
- }
- }
- expected_data = {
- 'nested': {
- 'one': 1,
- 'two': 2,
- }
- }
+ input_data = {"nested": {"one": "1", "two": "2"}}
+ expected_data = {"nested": {"one": 1, "two": 2}}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
assert serializer.validated_data == expected_data
def test_nested_serialize_empty(self):
- expected_data = {
- 'nested': {
- 'one': None,
- 'two': None
- }
- }
+ expected_data = {"nested": {"one": None, "two": None}}
serializer = self.Serializer()
assert serializer.data == expected_data
@@ -45,7 +30,7 @@ class TestNestedSerializer:
data = None
serializer = self.Serializer(data=data)
assert not serializer.is_valid()
- assert serializer.errors == {'non_field_errors': ['No data provided']}
+ assert serializer.errors == {"non_field_errors": ["No data provided"]}
class TestNotRequiredNestedSerializer:
@@ -63,16 +48,16 @@ class TestNotRequiredNestedSerializer:
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
- input_data = {'nested': {'one': '1'}}
+ input_data = {"nested": {"one": "1"}}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
def test_multipart_validate(self):
- input_data = QueryDict('')
+ input_data = QueryDict("")
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
- input_data = QueryDict('nested[one]=1')
+ input_data = QueryDict("nested[one]=1")
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
@@ -92,16 +77,16 @@ class TestNestedSerializerWithMany:
def test_null_allowed_if_allow_null_is_set(self):
input_data = {
- 'allow_null': None,
- 'not_allow_null': [{'example': '2'}, {'example': '3'}],
- 'allow_empty': [{'example': '2'}],
- 'not_allow_empty': [{'example': '2'}],
+ "allow_null": None,
+ "not_allow_null": [{"example": "2"}, {"example": "3"}],
+ "allow_empty": [{"example": "2"}],
+ "not_allow_empty": [{"example": "2"}],
}
expected_data = {
- 'allow_null': None,
- 'not_allow_null': [{'example': 2}, {'example': 3}],
- 'allow_empty': [{'example': 2}],
- 'not_allow_empty': [{'example': 2}],
+ "allow_null": None,
+ "not_allow_null": [{"example": 2}, {"example": 3}],
+ "allow_empty": [{"example": 2}],
+ "not_allow_empty": [{"example": 2}],
}
serializer = self.Serializer(data=input_data)
@@ -110,16 +95,16 @@ class TestNestedSerializerWithMany:
def test_null_is_not_allowed_if_allow_null_is_not_set(self):
input_data = {
- 'allow_null': None,
- 'not_allow_null': None,
- 'allow_empty': [{'example': '2'}],
- 'not_allow_empty': [{'example': '2'}],
+ "allow_null": None,
+ "not_allow_null": None,
+ "allow_empty": [{"example": "2"}],
+ "not_allow_empty": [{"example": "2"}],
}
serializer = self.Serializer(data=input_data)
assert not serializer.is_valid()
- expected_errors = {'not_allow_null': [serializer.error_messages['null']]}
+ expected_errors = {"not_allow_null": [serializer.error_messages["null"]]}
assert serializer.errors == expected_errors
def test_run_the_field_validation_even_if_the_field_is_null(self):
@@ -131,10 +116,10 @@ class TestNestedSerializerWithMany:
return value
input_data = {
- 'allow_null': None,
- 'not_allow_null': [{'example': 2}],
- 'allow_empty': [{'example': 2}],
- 'not_allow_empty': [{'example': 2}],
+ "allow_null": None,
+ "not_allow_null": [{"example": 2}],
+ "allow_empty": [{"example": 2}],
+ "not_allow_empty": [{"example": 2}],
}
serializer = TestSerializer(data=input_data)
@@ -144,16 +129,16 @@ class TestNestedSerializerWithMany:
def test_empty_allowed_if_allow_empty_is_set(self):
input_data = {
- 'allow_null': [{'example': '2'}],
- 'not_allow_null': [{'example': '2'}],
- 'allow_empty': [],
- 'not_allow_empty': [{'example': '2'}],
+ "allow_null": [{"example": "2"}],
+ "not_allow_null": [{"example": "2"}],
+ "allow_empty": [],
+ "not_allow_empty": [{"example": "2"}],
}
expected_data = {
- 'allow_null': [{'example': 2}],
- 'not_allow_null': [{'example': 2}],
- 'allow_empty': [],
- 'not_allow_empty': [{'example': 2}],
+ "allow_null": [{"example": 2}],
+ "not_allow_null": [{"example": 2}],
+ "allow_empty": [],
+ "not_allow_empty": [{"example": 2}],
}
serializer = self.Serializer(data=input_data)
@@ -162,16 +147,22 @@ class TestNestedSerializerWithMany:
def test_empty_not_allowed_if_allow_empty_is_set_to_false(self):
input_data = {
- 'allow_null': [{'example': '2'}],
- 'not_allow_null': [{'example': '2'}],
- 'allow_empty': [],
- 'not_allow_empty': [],
+ "allow_null": [{"example": "2"}],
+ "not_allow_null": [{"example": "2"}],
+ "allow_empty": [],
+ "not_allow_empty": [],
}
serializer = self.Serializer(data=input_data)
assert not serializer.is_valid()
- expected_errors = {'not_allow_empty': {'non_field_errors': [serializers.ListSerializer.default_error_messages['empty']]}}
+ expected_errors = {
+ "not_allow_empty": {
+ "non_field_errors": [
+ serializers.ListSerializer.default_error_messages["empty"]
+ ]
+ }
+ }
assert serializer.errors == expected_errors
@@ -186,22 +177,18 @@ class TestNestedSerializerWithList:
self.Serializer = TestSerializer
def test_nested_serializer_with_list_json(self):
- input_data = {
- 'nested': {
- 'example': [1, 2],
- }
- }
+ input_data = {"nested": {"example": [1, 2]}}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
- assert serializer.validated_data['nested']['example'] == {1, 2}
+ assert serializer.validated_data["nested"]["example"] == {1, 2}
def test_nested_serializer_with_list_multipart(self):
- input_data = QueryDict('nested.example=1&nested.example=2')
+ input_data = QueryDict("nested.example=1&nested.example=2")
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
- assert serializer.validated_data['nested']['example'] == {1, 2}
+ assert serializer.validated_data["nested"]["example"] == {1, 2}
class TestNotRequiredNestedSerializerWithMany:
@@ -220,24 +207,24 @@ class TestNotRequiredNestedSerializerWithMany:
# request is empty, therefor 'nested' should not be in serializer.data
assert serializer.is_valid()
- assert 'nested' not in serializer.validated_data
+ assert "nested" not in serializer.validated_data
- input_data = {'nested': [{'one': '1'}, {'one': 2}]}
+ input_data = {"nested": [{"one": "1"}, {"one": 2}]}
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
- assert 'nested' in serializer.validated_data
+ assert "nested" in serializer.validated_data
def test_multipart_validate(self):
# leave querydict empty
- input_data = QueryDict('')
+ input_data = QueryDict("")
serializer = self.Serializer(data=input_data)
# the querydict is empty, therefor 'nested' should not be in serializer.data
assert serializer.is_valid()
- assert 'nested' not in serializer.validated_data
+ assert "nested" not in serializer.validated_data
- input_data = QueryDict('nested[0]one=1&nested[1]one=2')
+ input_data = QueryDict("nested[0]one=1&nested[1]one=2")
serializer = self.Serializer(data=input_data)
assert serializer.is_valid()
- assert 'nested' in serializer.validated_data
+ assert "nested" in serializer.validated_data
diff --git a/tests/test_settings.py b/tests/test_settings.py
index 51e9751b2..84d30e56f 100644
--- a/tests/test_settings.py
+++ b/tests/test_settings.py
@@ -10,11 +10,9 @@ class TestSettings(TestCase):
"""
Make sure import errors are captured and raised sensibly.
"""
- settings = APISettings({
- 'DEFAULT_RENDERER_CLASSES': [
- 'tests.invalid_module.InvalidClassName'
- ]
- })
+ settings = APISettings(
+ {"DEFAULT_RENDERER_CLASSES": ["tests.invalid_module.InvalidClassName"]}
+ )
with self.assertRaises(ImportError):
settings.DEFAULT_RENDERER_CLASSES
@@ -24,9 +22,7 @@ class TestSettings(TestCase):
is set.
"""
with self.assertRaises(RuntimeError):
- APISettings({
- 'MAX_PAGINATE_BY': 100
- })
+ APISettings({"MAX_PAGINATE_BY": 100})
def test_compatibility_with_override_settings(self):
"""
@@ -40,7 +36,7 @@ class TestSettings(TestCase):
"""
assert api_settings.PAGE_SIZE is None, "Checking a known default should be None"
- with override_settings(REST_FRAMEWORK={'PAGE_SIZE': 10}):
+ with override_settings(REST_FRAMEWORK={"PAGE_SIZE": 10}):
assert api_settings.PAGE_SIZE == 10, "Setting should have been updated"
assert api_settings.PAGE_SIZE is None, "Setting should have been restored"
@@ -48,12 +44,10 @@ class TestSettings(TestCase):
class TestSettingTypes(TestCase):
def test_settings_consistently_coerced_to_list(self):
- settings = APISettings({
- 'DEFAULT_THROTTLE_CLASSES': ('rest_framework.throttling.BaseThrottle',)
- })
+ settings = APISettings(
+ {"DEFAULT_THROTTLE_CLASSES": ("rest_framework.throttling.BaseThrottle",)}
+ )
self.assertTrue(isinstance(settings.DEFAULT_THROTTLE_CLASSES, list))
- settings = APISettings({
- 'DEFAULT_THROTTLE_CLASSES': ()
- })
+ settings = APISettings({"DEFAULT_THROTTLE_CLASSES": ()})
self.assertTrue(isinstance(settings.DEFAULT_THROTTLE_CLASSES, list))
diff --git a/tests/test_status.py b/tests/test_status.py
index 1cd6e229e..e948557ff 100644
--- a/tests/test_status.py
+++ b/tests/test_status.py
@@ -3,8 +3,11 @@ from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.status import (
- is_client_error, is_informational, is_redirect, is_server_error,
- is_success
+ is_client_error,
+ is_informational,
+ is_redirect,
+ is_server_error,
+ is_success,
)
diff --git a/tests/test_templates.py b/tests/test_templates.py
index 19f511b96..60a8587e0 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -4,14 +4,14 @@ from django.shortcuts import render
def test_base_template_with_context():
- context = {'request': True, 'csrf_token': 'TOKEN'}
- result = render({}, 'rest_framework/base.html', context=context)
- assert re.search(r'\bcsrfToken: "TOKEN"', result.content.decode('utf-8'))
+ context = {"request": True, "csrf_token": "TOKEN"}
+ result = render({}, "rest_framework/base.html", context=context)
+ assert re.search(r'\bcsrfToken: "TOKEN"', result.content.decode("utf-8"))
def test_base_template_with_no_context():
# base.html should be renderable with no context,
# so it can be easily extended.
- result = render({}, 'rest_framework/base.html')
+ result = render({}, "rest_framework/base.html")
# note that this response will not include a valid CSRF token
- assert re.search(r'\bcsrfToken: ""', result.content.decode('utf-8'))
+ assert re.search(r'\bcsrfToken: ""', result.content.decode("utf-8"))
diff --git a/tests/test_templatetags.py b/tests/test_templatetags.py
index 45bfd4aeb..74ef16a12 100644
--- a/tests/test_templatetags.py
+++ b/tests/test_templatetags.py
@@ -10,11 +10,18 @@ from rest_framework.compat import coreapi, coreschema
from rest_framework.relations import Hyperlink
from rest_framework.templatetags import rest_framework
from rest_framework.templatetags.rest_framework import (
- add_nested_class, add_query_param, as_string, break_long_headers,
- format_value, get_pagination_html, schema_links, urlize_quoted_links
+ add_nested_class,
+ add_query_param,
+ as_string,
+ break_long_headers,
+ format_value,
+ get_pagination_html,
+ schema_links,
+ urlize_quoted_links,
)
from rest_framework.test import APIRequestFactory
+
factory = APIRequestFactory()
@@ -24,16 +31,15 @@ def format_html(html):
:param html: raw HTML text to be formatted
:return: Cleaned HTML with no newlines or spaces
"""
- return html.replace('\n', '').replace(' ', '')
+ return html.replace("\n", "").replace(" ", "")
class TemplateTagTests(TestCase):
-
def test_add_query_param_with_non_latin_character(self):
# Ensure we don't double-escape non-latin characters
# that are present in the querystring.
# See #1314.
- request = factory.get("/", {'q': '查询'})
+ request = factory.get("/", {"q": "查询"})
json_url = add_query_param(request, "format", "json")
self.assertIn("q=%E6%9F%A5%E8%AF%A2", json_url)
self.assertIn("format=json", json_url)
@@ -42,32 +48,32 @@ class TemplateTagTests(TestCase):
"""
Tests format_value with booleans and None
"""
- self.assertEqual(format_value(True), 'true
')
- self.assertEqual(format_value(False), 'false
')
- self.assertEqual(format_value(None), 'null
')
+ self.assertEqual(format_value(True), "true
")
+ self.assertEqual(format_value(False), "false
")
+ self.assertEqual(format_value(None), "null
")
def test_format_value_hyperlink(self):
"""
Tests format_value with a URL
"""
- url = 'http://url.com'
- name = 'name_of_url'
+ url = "http://url.com"
+ name = "name_of_url"
hyperlink = Hyperlink(url, name)
- self.assertEqual(format_value(hyperlink), '%s' % (url, name))
+ self.assertEqual(format_value(hyperlink), "%s" % (url, name))
def test_format_value_list(self):
"""
Tests format_value with a list of strings
"""
- list_items = ['item1', 'item2', 'item3']
- self.assertEqual(format_value(list_items), '\n item1, item2, item3\n')
- self.assertEqual(format_value([]), '\n\n')
+ list_items = ["item1", "item2", "item3"]
+ self.assertEqual(format_value(list_items), "\n item1, item2, item3\n")
+ self.assertEqual(format_value([]), "\n\n")
def test_format_value_dict(self):
"""
Tests format_value with a dict
"""
- test_dict = {'a': 'b'}
+ test_dict = {"a": "b"}
expected_dict_format = """
@@ -78,15 +84,14 @@ class TemplateTagTests(TestCase):
"""
self.assertEqual(
- format_html(format_value(test_dict)),
- format_html(expected_dict_format)
+ format_html(format_value(test_dict)), format_html(expected_dict_format)
)
def test_format_value_table(self):
"""
Tests format_value with a list of lists/dicts
"""
- list_of_lists = [['list1'], ['list2'], ['list3']]
+ list_of_lists = [["list1"], ["list2"], ["list3"]]
expected_list_format = """
@@ -105,8 +110,7 @@ class TemplateTagTests(TestCase):
"""
self.assertEqual(
- format_html(format_value(list_of_lists)),
- format_html(expected_list_format)
+ format_html(format_value(list_of_lists)), format_html(expected_list_format)
)
expected_dict_format = """
@@ -154,40 +158,48 @@ class TemplateTagTests(TestCase):
"""
- list_of_dicts = [{'item1': 'value1'}, {'item2': 'value2'}, {'item3': 'value3'}]
+ list_of_dicts = [{"item1": "value1"}, {"item2": "value2"}, {"item3": "value3"}]
self.assertEqual(
- format_html(format_value(list_of_dicts)),
- format_html(expected_dict_format)
+ format_html(format_value(list_of_dicts)), format_html(expected_dict_format)
)
def test_format_value_simple_string(self):
"""
Tests format_value with a simple string
"""
- simple_string = 'this is an example of a string'
+ simple_string = "this is an example of a string"
self.assertEqual(format_value(simple_string), simple_string)
def test_format_value_string_hyperlink(self):
"""
Tests format_value with a url
"""
- url = 'http://www.example.com'
- self.assertEqual(format_value(url), 'http://www.example.com')
+ url = "http://www.example.com"
+ self.assertEqual(
+ format_value(url),
+ 'http://www.example.com',
+ )
def test_format_value_string_email(self):
"""
Tests format_value with an email address
"""
- email = 'something@somewhere.com'
- self.assertEqual(format_value(email), 'something@somewhere.com')
+ email = "something@somewhere.com"
+ self.assertEqual(
+ format_value(email),
+ 'something@somewhere.com',
+ )
def test_format_value_string_newlines(self):
"""
Tests format_value with a string with newline characters
:return:
"""
- text = 'Dear user, \n this is a message \n from,\nsomeone'
- self.assertEqual(format_value(text), 'Dear user, \n this is a message \n from,\nsomeone
')
+ text = "Dear user, \n this is a message \n from,\nsomeone"
+ self.assertEqual(
+ format_value(text),
+ "Dear user, \n this is a message \n from,\nsomeone
",
+ )
def test_format_value_object(self):
"""
@@ -200,29 +212,19 @@ class TemplateTagTests(TestCase):
"""
Tests that add_nested_class returns the proper class
"""
- positive_cases = [
- [['item']],
- [{'item1': 'value1'}],
- {'item1': 'value1'}
- ]
+ positive_cases = [[["item"]], [{"item1": "value1"}], {"item1": "value1"}]
- negative_cases = [
- ['list'],
- '',
- None,
- True,
- False
- ]
+ negative_cases = [["list"], "", None, True, False]
for case in positive_cases:
- self.assertEqual(add_nested_class(case), 'class=nested')
+ self.assertEqual(add_nested_class(case), "class=nested")
for case in negative_cases:
- self.assertEqual(add_nested_class(case), '')
+ self.assertEqual(add_nested_class(case), "")
def test_as_string_with_none(self):
result = as_string(None)
- assert result == ''
+ assert result == ""
def test_get_pagination_html(self):
class MockPager(object):
@@ -237,8 +239,8 @@ class TemplateTagTests(TestCase):
assert pager.called is True
def test_break_long_lines(self):
- header = 'long test header,' * 20
- expected_header = '
' + ',
'.join(header.split(','))
+ header = "long test header," * 20
+ expected_header = "
" + ",
".join(header.split(","))
assert break_long_headers(header) == expected_header
@@ -251,21 +253,13 @@ class Issue1386Tests(TestCase):
"""
Test function urlize_quoted_links with different args
"""
- correct_urls = [
- "asdf.com",
- "asdf.net",
- "www.as_df.org",
- "as.d8f.ghj8.gov",
- ]
+ correct_urls = ["asdf.com", "asdf.net", "www.as_df.org", "as.d8f.ghj8.gov"]
for i in correct_urls:
res = urlize_quoted_links(i)
self.assertNotEqual(res, i)
self.assertIn(i, res)
- incorrect_urls = [
- "mailto://asdf@fdf.com",
- "asdf.netnet",
- ]
+ incorrect_urls = ["mailto://asdf@fdf.com", "asdf.netnet"]
for i in incorrect_urls:
res = urlize_quoted_links(i)
self.assertEqual(i, res)
@@ -279,7 +273,7 @@ class Issue1386Tests(TestCase):
old = rest_framework.smart_urlquote
rest_framework.smart_urlquote = mock_smart_urlquote
- assert rest_framework.smart_urlquote_wrapper('test') is None
+ assert rest_framework.smart_urlquote_wrapper("test") is None
rest_framework.smart_urlquote = old
@@ -287,6 +281,7 @@ class URLizerTests(TestCase):
"""
Test if JSON URLs are transformed into links well
"""
+
def _urlize_dict_check(self, data):
"""
For all items in dict test assert that the value is urlized key
@@ -299,337 +294,396 @@ class URLizerTests(TestCase):
Test if JSON URLs are transformed into links well
"""
data = {}
- data['"url": "http://api/users/1/", '] = \
- '"url": "http://api/users/1/", '
- data['"foo_set": [\n "http://api/foos/1/"\n], '] = \
- '"foo_set": [\n "http://api/foos/1/"\n], '
+ data[
+ '"url": "http://api/users/1/", '
+ ] = '"url": "http://api/users/1/", '
+ data[
+ '"foo_set": [\n "http://api/foos/1/"\n], '
+ ] = '"foo_set": [\n "http://api/foos/1/"\n], '
self._urlize_dict_check(data)
def test_template_render_with_autoescape(self):
"""
Test that HTML is correctly escaped in Browsable API views.
"""
- template = Template("{% load rest_framework %}{{ content|urlize_quoted_links }}")
- rendered = template.render(Context({'content': ' http://example.com'}))
- assert rendered == '<script>alert()</script>' \
- ' http://example.com'
+ template = Template(
+ "{% load rest_framework %}{{ content|urlize_quoted_links }}"
+ )
+ rendered = template.render(
+ Context({"content": " http://example.com"})
+ )
+ assert (
+ rendered == "<script>alert()</script>"
+ ' http://example.com'
+ )
def test_template_render_with_noautoescape(self):
"""
Test if the autoescape value is getting passed to urlize_quoted_links filter.
"""
- template = Template("{% load rest_framework %}"
- "{% autoescape off %}{{ content|urlize_quoted_links }}"
- "{% endautoescape %}")
- rendered = template.render(Context({'content': ' "http://example.com" '}))
- assert rendered == ' "http://example.com" '
+ template = Template(
+ "{% load rest_framework %}"
+ "{% autoescape off %}{{ content|urlize_quoted_links }}"
+ "{% endautoescape %}"
+ )
+ rendered = template.render(
+ Context({"content": ' "http://example.com" '})
+ )
+ assert (
+ rendered
+ == ' "http://example.com" '
+ )
-@unittest.skipUnless(coreapi, 'coreapi is not installed')
+@unittest.skipUnless(coreapi, "coreapi is not installed")
class SchemaLinksTests(TestCase):
-
def test_schema_with_empty_links(self):
schema = coreapi.Document(
- url='',
- title='Example API',
- content={
- 'users': {
- 'list': {}
- }
- }
+ url="", title="Example API", content={"users": {"list": {}}}
)
- section = schema['users']
+ section = schema["users"]
flat_links = schema_links(section)
assert len(flat_links) is 0
def test_single_action(self):
schema = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'users': {
- 'list': coreapi.Link(
- url='/users/',
- action='get',
- fields=[]
- )
- }
- }
+ "users": {"list": coreapi.Link(url="/users/", action="get", fields=[])}
+ },
)
- section = schema['users']
+ section = schema["users"]
flat_links = schema_links(section)
assert len(flat_links) is 1
- assert 'list' in flat_links
+ assert "list" in flat_links
def test_default_actions(self):
schema = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'users': {
- 'create': coreapi.Link(
- url='/users/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/users/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/users/{id}/',
- action='get',
+ "users": {
+ "create": coreapi.Link(url="/users/", action="post", fields=[]),
+ "list": coreapi.Link(url="/users/", action="get", fields=[]),
+ "read": coreapi.Link(
+ url="/users/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'update': coreapi.Link(
- url='/users/{id}/',
- action='patch',
+ "update": coreapi.Link(
+ url="/users/{id}/",
+ action="patch",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ ),
}
- }
+ },
)
- section = schema['users']
+ section = schema["users"]
flat_links = schema_links(section)
assert len(flat_links) is 4
- assert 'list' in flat_links
- assert 'create' in flat_links
- assert 'read' in flat_links
- assert 'update' in flat_links
+ assert "list" in flat_links
+ assert "create" in flat_links
+ assert "read" in flat_links
+ assert "update" in flat_links
def test_default_actions_and_single_custom_action(self):
schema = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'users': {
- 'create': coreapi.Link(
- url='/users/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/users/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/users/{id}/',
- action='get',
+ "users": {
+ "create": coreapi.Link(url="/users/", action="post", fields=[]),
+ "list": coreapi.Link(url="/users/", action="get", fields=[]),
+ "read": coreapi.Link(
+ url="/users/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'update': coreapi.Link(
- url='/users/{id}/',
- action='patch',
+ "update": coreapi.Link(
+ url="/users/{id}/",
+ action="patch",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'friends': coreapi.Link(
- url='/users/{id}/friends',
- action='get',
+ "friends": coreapi.Link(
+ url="/users/{id}/friends",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ ),
}
- }
+ },
)
- section = schema['users']
+ section = schema["users"]
flat_links = schema_links(section)
assert len(flat_links) is 5
- assert 'list' in flat_links
- assert 'create' in flat_links
- assert 'read' in flat_links
- assert 'update' in flat_links
- assert 'friends' in flat_links
+ assert "list" in flat_links
+ assert "create" in flat_links
+ assert "read" in flat_links
+ assert "update" in flat_links
+ assert "friends" in flat_links
def test_default_actions_and_single_custom_action_two_methods(self):
schema = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'users': {
- 'create': coreapi.Link(
- url='/users/',
- action='post',
- fields=[]
- ),
- 'list': coreapi.Link(
- url='/users/',
- action='get',
- fields=[]
- ),
- 'read': coreapi.Link(
- url='/users/{id}/',
- action='get',
+ "users": {
+ "create": coreapi.Link(url="/users/", action="post", fields=[]),
+ "list": coreapi.Link(url="/users/", action="get", fields=[]),
+ "read": coreapi.Link(
+ url="/users/{id}/",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'update': coreapi.Link(
- url='/users/{id}/',
- action='patch',
+ "update": coreapi.Link(
+ url="/users/{id}/",
+ action="patch",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'friends': {
- 'list': coreapi.Link(
- url='/users/{id}/friends',
- action='get',
+ "friends": {
+ "list": coreapi.Link(
+ url="/users/{id}/friends",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'create': coreapi.Link(
- url='/users/{id}/friends',
- action='post',
+ "create": coreapi.Link(
+ url="/users/{id}/friends",
+ action="post",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- }
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ ),
+ },
}
- }
+ },
)
- section = schema['users']
+ section = schema["users"]
flat_links = schema_links(section)
assert len(flat_links) is 6
- assert 'list' in flat_links
- assert 'create' in flat_links
- assert 'read' in flat_links
- assert 'update' in flat_links
- assert 'friends > list' in flat_links
- assert 'friends > create' in flat_links
+ assert "list" in flat_links
+ assert "create" in flat_links
+ assert "read" in flat_links
+ assert "update" in flat_links
+ assert "friends > list" in flat_links
+ assert "friends > create" in flat_links
def test_multiple_nested_routes(self):
schema = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'animals': {
- 'dog': {
- 'vet': {
- 'list': coreapi.Link(
- url='/animals/dog/{id}/vet',
- action='get',
+ "animals": {
+ "dog": {
+ "vet": {
+ "list": coreapi.Link(
+ url="/animals/dog/{id}/vet",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
)
},
- 'read': coreapi.Link(
- url='/animals/dog/{id}',
- action='get',
+ "read": coreapi.Link(
+ url="/animals/dog/{id}",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- },
- 'cat': {
- 'list': coreapi.Link(
- url='/animals/cat/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'create': coreapi.Link(
- url='/aniamls/cat',
- action='post',
- fields=[]
- )
- }
+ },
+ "cat": {
+ "list": coreapi.Link(
+ url="/animals/cat/",
+ action="get",
+ fields=[
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ ),
+ "create": coreapi.Link(
+ url="/aniamls/cat", action="post", fields=[]
+ ),
+ },
}
- }
+ },
)
- section = schema['animals']
+ section = schema["animals"]
flat_links = schema_links(section)
assert len(flat_links) is 4
- assert 'cat > create' in flat_links
- assert 'cat > list' in flat_links
- assert 'dog > read' in flat_links
- assert 'dog > vet > list' in flat_links
+ assert "cat > create" in flat_links
+ assert "cat > list" in flat_links
+ assert "dog > read" in flat_links
+ assert "dog > vet > list" in flat_links
def test_multiple_resources_with_multiple_nested_routes(self):
schema = coreapi.Document(
- url='',
- title='Example API',
+ url="",
+ title="Example API",
content={
- 'animals': {
- 'dog': {
- 'vet': {
- 'list': coreapi.Link(
- url='/animals/dog/{id}/vet',
- action='get',
+ "animals": {
+ "dog": {
+ "vet": {
+ "list": coreapi.Link(
+ url="/animals/dog/{id}/vet",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
)
},
- 'read': coreapi.Link(
- url='/animals/dog/{id}',
- action='get',
+ "read": coreapi.Link(
+ url="/animals/dog/{id}",
+ action="get",
fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- },
- 'cat': {
- 'list': coreapi.Link(
- url='/animals/cat/',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ ),
+ },
+ "cat": {
+ "list": coreapi.Link(
+ url="/animals/cat/",
+ action="get",
+ fields=[
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ ),
+ "create": coreapi.Link(
+ url="/aniamls/cat", action="post", fields=[]
+ ),
+ },
+ },
+ "farmers": {
+ "silo": {
+ "soy": {
+ "list": coreapi.Link(
+ url="/farmers/silo/{id}/soy",
+ action="get",
+ fields=[
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
+ )
+ },
+ "list": coreapi.Link(
+ url="/farmers/silo",
+ action="get",
+ fields=[
+ coreapi.Field(
+ "id",
+ required=True,
+ location="path",
+ schema=coreschema.String(),
+ )
+ ],
),
- 'create': coreapi.Link(
- url='/aniamls/cat',
- action='post',
- fields=[]
- )
}
},
- 'farmers': {
- 'silo': {
- 'soy': {
- 'list': coreapi.Link(
- url='/farmers/silo/{id}/soy',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- },
- 'list': coreapi.Link(
- url='/farmers/silo',
- action='get',
- fields=[
- coreapi.Field('id', required=True, location='path', schema=coreschema.String())
- ]
- )
- }
- }
- }
+ },
)
- section = schema['animals']
+ section = schema["animals"]
flat_links = schema_links(section)
assert len(flat_links) is 4
- assert 'cat > create' in flat_links
- assert 'cat > list' in flat_links
- assert 'dog > read' in flat_links
- assert 'dog > vet > list' in flat_links
+ assert "cat > create" in flat_links
+ assert "cat > list" in flat_links
+ assert "dog > read" in flat_links
+ assert "dog > vet > list" in flat_links
- section = schema['farmers']
+ section = schema["farmers"]
flat_links = schema_links(section)
assert len(flat_links) is 2
- assert 'silo > list' in flat_links
- assert 'silo > soy > list' in flat_links
+ assert "silo > list" in flat_links
+ assert "silo > soy > list" in flat_links
diff --git a/tests/test_testing.py b/tests/test_testing.py
index 7868f724c..98348fb3b 100644
--- a/tests/test_testing.py
+++ b/tests/test_testing.py
@@ -12,37 +12,40 @@ from rest_framework import fields, serializers
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.test import (
- APIClient, APIRequestFactory, URLPatternsTestCase, force_authenticate
+ APIClient,
+ APIRequestFactory,
+ URLPatternsTestCase,
+ force_authenticate,
)
-@api_view(['GET', 'POST'])
+@api_view(["GET", "POST"])
def view(request):
- return Response({
- 'auth': request.META.get('HTTP_AUTHORIZATION', b''),
- 'user': request.user.username
- })
+ return Response(
+ {
+ "auth": request.META.get("HTTP_AUTHORIZATION", b""),
+ "user": request.user.username,
+ }
+ )
-@api_view(['GET', 'POST'])
+@api_view(["GET", "POST"])
def session_view(request):
- active_session = request.session.get('active_session', False)
- request.session['active_session'] = True
- return Response({
- 'active_session': active_session
- })
+ active_session = request.session.get("active_session", False)
+ request.session["active_session"] = True
+ return Response({"active_session": active_session})
-@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
+@api_view(["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
def redirect_view(request):
- return redirect('/view/')
+ return redirect("/view/")
class BasicSerializer(serializers.Serializer):
flag = fields.BooleanField(default=lambda: True)
-@api_view(['POST'])
+@api_view(["POST"])
def post_view(request):
serializer = BasicSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -50,14 +53,14 @@ def post_view(request):
urlpatterns = [
- url(r'^view/$', view),
- url(r'^session-view/$', session_view),
- url(r'^redirect-view/$', redirect_view),
- url(r'^post-view/$', post_view)
+ url(r"^view/$", view),
+ url(r"^session-view/$", session_view),
+ url(r"^redirect-view/$", redirect_view),
+ url(r"^post-view/$", post_view),
]
-@override_settings(ROOT_URLCONF='tests.test_testing')
+@override_settings(ROOT_URLCONF="tests.test_testing")
class TestAPITestClient(TestCase):
def setUp(self):
self.client = APIClient()
@@ -66,47 +69,47 @@ class TestAPITestClient(TestCase):
"""
Setting `.credentials()` adds the required headers to each request.
"""
- self.client.credentials(HTTP_AUTHORIZATION='example')
+ self.client.credentials(HTTP_AUTHORIZATION="example")
for _ in range(0, 3):
- response = self.client.get('/view/')
- assert response.data['auth'] == 'example'
+ response = self.client.get("/view/")
+ assert response.data["auth"] == "example"
def test_force_authenticate(self):
"""
Setting `.force_authenticate()` forcibly authenticates each request.
"""
- user = User.objects.create_user('example', 'example@example.com')
+ user = User.objects.create_user("example", "example@example.com")
self.client.force_authenticate(user)
- response = self.client.get('/view/')
- assert response.data['user'] == 'example'
+ response = self.client.get("/view/")
+ assert response.data["user"] == "example"
def test_force_authenticate_with_sessions(self):
"""
Setting `.force_authenticate()` forcibly authenticates each request.
"""
- user = User.objects.create_user('example', 'example@example.com')
+ user = User.objects.create_user("example", "example@example.com")
self.client.force_authenticate(user)
# First request does not yet have an active session
- response = self.client.get('/session-view/')
- assert response.data['active_session'] is False
+ response = self.client.get("/session-view/")
+ assert response.data["active_session"] is False
# Subsequent requests have an active session
- response = self.client.get('/session-view/')
- assert response.data['active_session'] is True
+ response = self.client.get("/session-view/")
+ assert response.data["active_session"] is True
# Force authenticating as `None` should also logout the user session.
self.client.force_authenticate(None)
- response = self.client.get('/session-view/')
- assert response.data['active_session'] is False
+ response = self.client.get("/session-view/")
+ assert response.data["active_session"] is False
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
"""
- User.objects.create_user('example', 'example@example.com', 'password')
- self.client.login(username='example', password='password')
- response = self.client.post('/view/')
+ User.objects.create_user("example", "example@example.com", "password")
+ self.client.login(username="example", password="password")
+ response = self.client.post("/view/")
assert response.status_code == 200
def test_explicitly_enforce_csrf_checks(self):
@@ -114,10 +117,10 @@ class TestAPITestClient(TestCase):
The test client can enforce CSRF checks.
"""
client = APIClient(enforce_csrf_checks=True)
- User.objects.create_user('example', 'example@example.com', 'password')
- client.login(username='example', password='password')
- response = client.post('/view/')
- expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ User.objects.create_user("example", "example@example.com", "password")
+ client.login(username="example", password="password")
+ response = client.post("/view/")
+ expected = {"detail": "CSRF Failed: CSRF cookie not set."}
assert response.status_code == 403
assert response.data == expected
@@ -125,62 +128,62 @@ class TestAPITestClient(TestCase):
"""
`logout()` resets stored credentials
"""
- self.client.credentials(HTTP_AUTHORIZATION='example')
- response = self.client.get('/view/')
- assert response.data['auth'] == 'example'
+ self.client.credentials(HTTP_AUTHORIZATION="example")
+ response = self.client.get("/view/")
+ assert response.data["auth"] == "example"
self.client.logout()
- response = self.client.get('/view/')
- assert response.data['auth'] == b''
+ response = self.client.get("/view/")
+ assert response.data["auth"] == b""
def test_logout_resets_force_authenticate(self):
"""
`logout()` resets any `force_authenticate`
"""
- user = User.objects.create_user('example', 'example@example.com', 'password')
+ user = User.objects.create_user("example", "example@example.com", "password")
self.client.force_authenticate(user)
- response = self.client.get('/view/')
- assert response.data['user'] == 'example'
+ response = self.client.get("/view/")
+ assert response.data["user"] == "example"
self.client.logout()
- response = self.client.get('/view/')
- assert response.data['user'] == ''
+ response = self.client.get("/view/")
+ assert response.data["user"] == ""
def test_follow_redirect(self):
"""
Follow redirect by setting follow argument.
"""
- response = self.client.get('/redirect-view/')
+ response = self.client.get("/redirect-view/")
assert response.status_code == 302
- response = self.client.get('/redirect-view/', follow=True)
+ response = self.client.get("/redirect-view/", follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
- response = self.client.post('/redirect-view/')
+ response = self.client.post("/redirect-view/")
assert response.status_code == 302
- response = self.client.post('/redirect-view/', follow=True)
+ response = self.client.post("/redirect-view/", follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
- response = self.client.put('/redirect-view/')
+ response = self.client.put("/redirect-view/")
assert response.status_code == 302
- response = self.client.put('/redirect-view/', follow=True)
+ response = self.client.put("/redirect-view/", follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
- response = self.client.patch('/redirect-view/')
+ response = self.client.patch("/redirect-view/")
assert response.status_code == 302
- response = self.client.patch('/redirect-view/', follow=True)
+ response = self.client.patch("/redirect-view/", follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
- response = self.client.delete('/redirect-view/')
+ response = self.client.delete("/redirect-view/")
assert response.status_code == 302
- response = self.client.delete('/redirect-view/', follow=True)
+ response = self.client.delete("/redirect-view/", follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
- response = self.client.options('/redirect-view/')
+ response = self.client.options("/redirect-view/")
assert response.status_code == 302
- response = self.client.options('/redirect-view/', follow=True)
+ response = self.client.options("/redirect-view/", follow=True)
assert response.redirect_chain is not None
assert response.status_code == 200
@@ -190,15 +193,15 @@ class TestAPITestClient(TestCase):
error if the user attempts to do so.
"""
self.assertRaises(
- AssertionError, self.client.post,
- path='/view/', data={'valid': 123, 'invalid': {'a': 123}}
+ AssertionError,
+ self.client.post,
+ path="/view/",
+ data={"valid": 123, "invalid": {"a": 123}},
)
def test_empty_post_uses_default_boolean_value(self):
response = self.client.post(
- '/post-view/',
- data=None,
- content_type='application/json'
+ "/post-view/", data=None, content_type="application/json"
)
assert response.status_code == 200
assert response.data == {"flag": True}
@@ -209,9 +212,9 @@ class TestAPIRequestFactory(TestCase):
"""
By default, the test client is CSRF exempt.
"""
- user = User.objects.create_user('example', 'example@example.com', 'password')
+ user = User.objects.create_user("example", "example@example.com", "password")
factory = APIRequestFactory()
- request = factory.post('/view/')
+ request = factory.post("/view/")
request.user = user
response = view(request)
assert response.status_code == 200
@@ -220,12 +223,12 @@ class TestAPIRequestFactory(TestCase):
"""
The test client can enforce CSRF checks.
"""
- user = User.objects.create_user('example', 'example@example.com', 'password')
+ user = User.objects.create_user("example", "example@example.com", "password")
factory = APIRequestFactory(enforce_csrf_checks=True)
- request = factory.post('/view/')
+ request = factory.post("/view/")
request.user = user
response = view(request)
- expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
+ expected = {"detail": "CSRF Failed: CSRF cookie not set."}
assert response.status_code == 403
assert response.data == expected
@@ -236,59 +239,60 @@ class TestAPIRequestFactory(TestCase):
"""
factory = APIRequestFactory()
self.assertRaises(
- AssertionError, factory.post,
- path='/view/', data={'example': 1}, format='xml'
+ AssertionError,
+ factory.post,
+ path="/view/",
+ data={"example": 1},
+ format="xml",
)
def test_force_authenticate(self):
"""
Setting `force_authenticate()` forcibly authenticates the request.
"""
- user = User.objects.create_user('example', 'example@example.com')
+ user = User.objects.create_user("example", "example@example.com")
factory = APIRequestFactory()
- request = factory.get('/view')
+ request = factory.get("/view")
force_authenticate(request, user=user)
response = view(request)
- assert response.data['user'] == 'example'
+ assert response.data["user"] == "example"
def test_upload_file(self):
# This is a 1x1 black png
- simple_png = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82')
- simple_png.name = 'test.png'
+ simple_png = BytesIO(
+ b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00\x00\x00IEND\xaeB`\x82"
+ )
+ simple_png.name = "test.png"
factory = APIRequestFactory()
- factory.post('/', data={'image': simple_png})
+ factory.post("/", data={"image": simple_png})
def test_request_factory_url_arguments(self):
"""
This is a non regression test against #1461
"""
factory = APIRequestFactory()
- request = factory.get('/view/?demo=test')
- assert dict(request.GET) == {'demo': ['test']}
- request = factory.get('/view/', {'demo': 'test'})
- assert dict(request.GET) == {'demo': ['test']}
+ request = factory.get("/view/?demo=test")
+ assert dict(request.GET) == {"demo": ["test"]}
+ request = factory.get("/view/", {"demo": "test"})
+ assert dict(request.GET) == {"demo": ["test"]}
def test_request_factory_url_arguments_with_unicode(self):
factory = APIRequestFactory()
- request = factory.get('/view/?demo=testé')
- assert dict(request.GET) == {'demo': ['testé']}
- request = factory.get('/view/', {'demo': 'testé'})
- assert dict(request.GET) == {'demo': ['testé']}
+ request = factory.get("/view/?demo=testé")
+ assert dict(request.GET) == {"demo": ["testé"]}
+ request = factory.get("/view/", {"demo": "testé"})
+ assert dict(request.GET) == {"demo": ["testé"]}
def test_empty_request_content_type(self):
factory = APIRequestFactory()
request = factory.post(
- '/post-view/',
- data=None,
- content_type='application/json',
+ "/post-view/", data=None, content_type="application/json"
)
- assert request.META['CONTENT_TYPE'] == 'application/json'
+ assert request.META["CONTENT_TYPE"] == "application/json"
class TestUrlPatternTestCase(URLPatternsTestCase):
- urlpatterns = [
- url(r'^$', view),
- ]
+ urlpatterns = [url(r"^$", view)]
@classmethod
def setUpClass(cls):
@@ -303,10 +307,10 @@ class TestUrlPatternTestCase(URLPatternsTestCase):
assert urlpatterns is not cls.urlpatterns
def test_urlpatterns(self):
- assert self.client.get('/').status_code == 200
+ assert self.client.get("/").status_code == 200
class TestExistingPatterns(TestCase):
def test_urlpatterns(self):
# sanity test to ensure that this test module does not have a '/' route
- assert self.client.get('/').status_code == 404
+ assert self.client.get("/").status_code == 404
diff --git a/tests/test_throttling.py b/tests/test_throttling.py
index b220a33a6..99497098a 100644
--- a/tests/test_throttling.py
+++ b/tests/test_throttling.py
@@ -15,25 +15,28 @@ from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.throttling import (
- AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
- UserRateThrottle
+ AnonRateThrottle,
+ BaseThrottle,
+ ScopedRateThrottle,
+ SimpleRateThrottle,
+ UserRateThrottle,
)
from rest_framework.views import APIView
class User3SecRateThrottle(UserRateThrottle):
- rate = '3/sec'
- scope = 'seconds'
+ rate = "3/sec"
+ scope = "seconds"
class User3MinRateThrottle(UserRateThrottle):
- rate = '3/min'
- scope = 'minutes'
+ rate = "3/min"
+ scope = "minutes"
class NonTimeThrottle(BaseThrottle):
def allow_request(self, request, view):
- if not hasattr(self.__class__, 'called'):
+ if not hasattr(self.__class__, "called"):
self.__class__.called = True
return True
return False
@@ -43,21 +46,21 @@ class MockView(APIView):
throttle_classes = (User3SecRateThrottle,)
def get(self, request):
- return Response('foo')
+ return Response("foo")
class MockView_MinuteThrottling(APIView):
throttle_classes = (User3MinRateThrottle,)
def get(self, request):
- return Response('foo')
+ return Response("foo")
class MockView_NonTimeThrottling(APIView):
throttle_classes = (NonTimeThrottle,)
def get(self, request):
- return Response('foo')
+ return Response("foo")
class ThrottlingTests(TestCase):
@@ -72,7 +75,7 @@ class ThrottlingTests(TestCase):
"""
Ensure request rate is limited
"""
- request = self.factory.get('/')
+ request = self.factory.get("/")
for dummy in range(4):
response = MockView.as_view()(request)
assert response.status_code == 429
@@ -89,7 +92,7 @@ class ThrottlingTests(TestCase):
"""
self.set_throttle_timer(MockView, 0)
- request = self.factory.get('/')
+ request = self.factory.get("/")
for dummy in range(4):
response = MockView.as_view()(request)
assert response.status_code == 429
@@ -101,11 +104,11 @@ class ThrottlingTests(TestCase):
assert response.status_code == 200
def ensure_is_throttled(self, view, expect):
- request = self.factory.get('/')
- request.user = User.objects.create(username='a')
+ request = self.factory.get("/")
+ request.user = User.objects.create(username="a")
for dummy in range(3):
view.as_view()(request)
- request.user = User.objects.create(username='b')
+ request.user = User.objects.create(username="b")
response = view.as_view()(request)
assert response.status_code == expect
@@ -116,31 +119,28 @@ class ThrottlingTests(TestCase):
"""
self.ensure_is_throttled(MockView, 200)
- def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
+ def ensure_response_header_contains_proper_throttle_field(
+ self, view, expected_headers
+ ):
"""
Ensure the response returns an Retry-After field with status and next attributes
set properly.
"""
- request = self.factory.get('/')
+ request = self.factory.get("/")
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = view.as_view()(request)
if expect is not None:
- assert response['Retry-After'] == expect
+ assert response["Retry-After"] == expect
else:
- assert not'Retry-After' in response
+ assert "Retry-After" not in response
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(
- MockView, (
- (0, None),
- (0, None),
- (0, None),
- (0, '1')
- )
+ MockView, ((0, None), (0, None), (0, None), (0, "1"))
)
def test_minutes_fields(self):
@@ -148,12 +148,7 @@ class ThrottlingTests(TestCase):
Ensure for minute based throttles.
"""
self.ensure_response_header_contains_proper_throttle_field(
- MockView_MinuteThrottling, (
- (0, None),
- (0, None),
- (0, None),
- (0, '60')
- )
+ MockView_MinuteThrottling, ((0, None), (0, None), (0, None), (0, "60"))
)
def test_next_rate_remains_constant_if_followed(self):
@@ -162,30 +157,27 @@ class ThrottlingTests(TestCase):
the throttling rate should stay constant.
"""
self.ensure_response_header_contains_proper_throttle_field(
- MockView_MinuteThrottling, (
- (0, None),
- (20, None),
- (40, None),
- (60, None),
- (80, None)
- )
+ MockView_MinuteThrottling,
+ ((0, None), (20, None), (40, None), (60, None), (80, None)),
)
def test_non_time_throttle(self):
"""
Ensure for second based throttles.
"""
- request = self.factory.get('/')
+ request = self.factory.get("/")
- self.assertFalse(hasattr(MockView_NonTimeThrottling.throttle_classes[0], 'called'))
+ self.assertFalse(
+ hasattr(MockView_NonTimeThrottling.throttle_classes[0], "called")
+ )
response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('Retry-After' in response)
+ self.assertFalse("Retry-After" in response)
self.assertTrue(MockView_NonTimeThrottling.throttle_classes[0].called)
response = MockView_NonTimeThrottling.as_view()(request)
- self.assertFalse('Retry-After' in response)
+ self.assertFalse("Retry-After" in response)
class ScopedRateThrottleTests(TestCase):
@@ -198,30 +190,30 @@ class ScopedRateThrottleTests(TestCase):
class XYScopedRateThrottle(ScopedRateThrottle):
TIMER_SECONDS = 0
- THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
+ THROTTLE_RATES = {"x": "3/min", "y": "1/min"}
def timer(self):
return self.TIMER_SECONDS
class XView(APIView):
throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'x'
+ throttle_scope = "x"
def get(self, request):
- return Response('x')
+ return Response("x")
class YView(APIView):
throttle_classes = (XYScopedRateThrottle,)
- throttle_scope = 'y'
+ throttle_scope = "y"
def get(self, request):
- return Response('y')
+ return Response("y")
class UnscopedView(APIView):
throttle_classes = (XYScopedRateThrottle,)
def get(self, request):
- return Response('y')
+ return Response("y")
self.throttle_class = XYScopedRateThrottle
self.factory = APIRequestFactory()
@@ -233,7 +225,7 @@ class ScopedRateThrottleTests(TestCase):
self.throttle_class.TIMER_SECONDS += seconds
def test_scoped_rate_throttle(self):
- request = self.factory.get('/')
+ request = self.factory.get("/")
# Should be able to hit x view 3 times per minute.
response = self.x_view(request)
@@ -288,7 +280,7 @@ class ScopedRateThrottleTests(TestCase):
assert response.status_code == 429
def test_unscoped_view_not_throttled(self):
- request = self.factory.get('/')
+ request = self.factory.get("/")
for idx in range(10):
self.increment_timer()
@@ -297,22 +289,21 @@ class ScopedRateThrottleTests(TestCase):
def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self):
class DummyView(object):
- throttle_scope = 'user'
+ throttle_scope = "user"
request = Request(HttpRequest())
- user = User.objects.create(username='test')
+ user = User.objects.create(username="test")
force_authenticate(request, user)
request.user = user
self.throttle.allow_request(request, DummyView())
cache_key = self.throttle.get_cache_key(request, view=DummyView())
- assert cache_key == 'throttle_user_%s' % user.pk
+ assert cache_key == "throttle_user_%s" % user.pk
class XffTestingBase(TestCase):
def setUp(self):
-
class Throttle(ScopedRateThrottle):
- THROTTLE_RATES = {'test_limit': '1/day'}
+ THROTTLE_RATES = {"test_limit": "1/day"}
TIMER_SECONDS = 0
def timer(self):
@@ -320,20 +311,20 @@ class XffTestingBase(TestCase):
class View(APIView):
throttle_classes = (Throttle,)
- throttle_scope = 'test_limit'
+ throttle_scope = "test_limit"
def get(self, request):
- return Response('test_limit')
+ return Response("test_limit")
cache.clear()
self.throttle = Throttle()
self.view = View.as_view()
- self.request = APIRequestFactory().get('/some_uri')
- self.request.META['REMOTE_ADDR'] = '3.3.3.3'
- self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'
+ self.request = APIRequestFactory().get("/some_uri")
+ self.request.META["REMOTE_ADDR"] = "3.3.3.3"
+ self.request.META["HTTP_X_FORWARDED_FOR"] = "0.0.0.0, 1.1.1.1, 2.2.2.2"
def config_proxy(self, num_proxies):
- setattr(api_settings, 'NUM_PROXIES', num_proxies)
+ setattr(api_settings, "NUM_PROXIES", num_proxies)
class IdWithXffBasicTests(XffTestingBase):
@@ -351,13 +342,13 @@ class XffSpoofingTests(XffTestingBase):
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
self.config_proxy(1)
self.view(self.request)
- self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
+ self.request.META["HTTP_X_FORWARDED_FOR"] = "4.4.4.4, 5.5.5.5, 2.2.2.2"
assert self.view(self.request).status_code == 429
def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
self.config_proxy(2)
self.view(self.request)
- self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
+ self.request.META["HTTP_X_FORWARDED_FOR"] = "4.4.4.4, 1.1.1.1, 2.2.2.2"
assert self.view(self.request).status_code == 429
@@ -365,27 +356,25 @@ class XffUniqueMachinesTest(XffTestingBase):
def test_unique_clients_are_counted_independently_with_one_proxy(self):
self.config_proxy(1)
self.view(self.request)
- self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
+ self.request.META["HTTP_X_FORWARDED_FOR"] = "0.0.0.0, 1.1.1.1, 7.7.7.7"
assert self.view(self.request).status_code == 200
def test_unique_clients_are_counted_independently_with_two_proxies(self):
self.config_proxy(2)
self.view(self.request)
- self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
+ self.request.META["HTTP_X_FORWARDED_FOR"] = "0.0.0.0, 7.7.7.7, 2.2.2.2"
assert self.view(self.request).status_code == 200
class BaseThrottleTests(TestCase):
-
def test_allow_request_raises_not_implemented_error(self):
with pytest.raises(NotImplementedError):
BaseThrottle().allow_request(request={}, view={})
class SimpleRateThrottleTests(TestCase):
-
def setUp(self):
- SimpleRateThrottle.scope = 'anon'
+ SimpleRateThrottle.scope = "anon"
def test_get_rate_raises_error_if_scope_is_missing(self):
throttle = SimpleRateThrottle()
@@ -394,7 +383,7 @@ class SimpleRateThrottleTests(TestCase):
throttle.get_rate()
def test_throttle_raises_error_if_rate_is_missing(self):
- SimpleRateThrottle.scope = 'invalid scope'
+ SimpleRateThrottle.scope = "invalid scope"
with pytest.raises(ImproperlyConfigured):
SimpleRateThrottle()
@@ -411,7 +400,7 @@ class SimpleRateThrottleTests(TestCase):
def test_allow_request_returns_true_if_key_is_none(self):
throttle = SimpleRateThrottle()
- throttle.rate = 'some rate'
+ throttle.rate = "some rate"
throttle.get_cache_key = lambda *args: None
assert throttle.allow_request(request={}, view={}) is True
@@ -434,13 +423,12 @@ class SimpleRateThrottleTests(TestCase):
class AnonRateThrottleTests(TestCase):
-
def setUp(self):
self.throttle = AnonRateThrottle()
def test_authenticated_user_not_affected(self):
request = Request(HttpRequest())
- user = User.objects.create(username='test')
+ user = User.objects.create(username="test")
force_authenticate(request, user)
request.user = user
assert self.throttle.get_cache_key(request, view={}) is None
@@ -448,4 +436,4 @@ class AnonRateThrottleTests(TestCase):
def test_get_cache_key_returns_correct_value(self):
request = Request(HttpRequest())
cache_key = self.throttle.get_cache_key(request, view={})
- assert cache_key == 'throttle_anon_None'
+ assert cache_key == "throttle_anon_None"
diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py
index 59ba395d2..0a6875074 100644
--- a/tests/test_urlpatterns.py
+++ b/tests/test_urlpatterns.py
@@ -11,8 +11,9 @@ from rest_framework.compat import make_url_resolver, path, re_path
from rest_framework.test import APIRequestFactory
from rest_framework.urlpatterns import format_suffix_patterns
+
# A container class for test paths for the test case
-URLTestPath = namedtuple('URLTestPath', ['path', 'args', 'kwargs'])
+URLTestPath = namedtuple("URLTestPath", ["path", "args", "kwargs"])
def dummy_view(request, *args, **kwargs):
@@ -24,13 +25,16 @@ class FormatSuffixTests(TestCase):
Tests `format_suffix_patterns` against different URLPatterns to ensure the
URLs still resolve properly, including any captured parameters.
"""
+
def _resolve_urlpatterns(self, urlpatterns, test_paths, allowed=None):
factory = APIRequestFactory()
try:
urlpatterns = format_suffix_patterns(urlpatterns, allowed=allowed)
except Exception:
- self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
- resolver = make_url_resolver(r'^/', urlpatterns)
+ self.fail(
+ "Failed to apply `format_suffix_patterns` on the supplied urlpatterns"
+ )
+ resolver = make_url_resolver(r"^/", urlpatterns)
for test_path in test_paths:
try:
test_path, expected_resolved = test_path
@@ -39,7 +43,9 @@ class FormatSuffixTests(TestCase):
request = factory.get(test_path.path)
try:
- callback, callback_args, callback_kwargs = resolver.resolve(request.path_info)
+ callback, callback_args, callback_kwargs = resolver.resolve(
+ request.path_info
+ )
except Resolver404:
callback, callback_args, callback_kwargs = (None, None, None)
if expected_resolved:
@@ -56,171 +62,180 @@ class FormatSuffixTests(TestCase):
def _test_trailing_slash(self, urlpatterns):
test_paths = [
- (URLTestPath('/test.api', (), {'format': 'api'}), True),
- (URLTestPath('/test/.api', (), {'format': 'api'}), False),
- (URLTestPath('/test.api/', (), {'format': 'api'}), True),
+ (URLTestPath("/test.api", (), {"format": "api"}), True),
+ (URLTestPath("/test/.api", (), {"format": "api"}), False),
+ (URLTestPath("/test.api/", (), {"format": "api"}), True),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_trailing_slash(self):
- urlpatterns = [
- url(r'^test/$', dummy_view),
- ]
+ urlpatterns = [url(r"^test/$", dummy_view)]
self._test_trailing_slash(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_trailing_slash_django2(self):
- urlpatterns = [
- path('test/', dummy_view),
- ]
+ urlpatterns = [path("test/", dummy_view)]
self._test_trailing_slash(urlpatterns)
def _test_format_suffix(self, urlpatterns):
test_paths = [
- URLTestPath('/test', (), {}),
- URLTestPath('/test.api', (), {'format': 'api'}),
- URLTestPath('/test.asdf', (), {'format': 'asdf'}),
+ URLTestPath("/test", (), {}),
+ URLTestPath("/test.api", (), {"format": "api"}),
+ URLTestPath("/test.asdf", (), {"format": "asdf"}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_format_suffix(self):
- urlpatterns = [
- url(r'^test$', dummy_view),
- ]
+ urlpatterns = [url(r"^test$", dummy_view)]
self._test_format_suffix(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_format_suffix_django2(self):
- urlpatterns = [
- path('test', dummy_view),
- ]
+ urlpatterns = [path("test", dummy_view)]
self._test_format_suffix(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_format_suffix_django2_args(self):
urlpatterns = [
- path('convtest/', dummy_view),
- re_path(r'^retest/(?P[0-9]+)$', dummy_view),
+ path("convtest/", dummy_view),
+ re_path(r"^retest/(?P[0-9]+)$", dummy_view),
]
test_paths = [
- URLTestPath('/convtest/42', (), {'pk': 42}),
- URLTestPath('/convtest/42.api', (), {'pk': 42, 'format': 'api'}),
- URLTestPath('/convtest/42.asdf', (), {'pk': 42, 'format': 'asdf'}),
- URLTestPath('/retest/42', (), {'pk': '42'}),
- URLTestPath('/retest/42.api', (), {'pk': '42', 'format': 'api'}),
- URLTestPath('/retest/42.asdf', (), {'pk': '42', 'format': 'asdf'}),
+ URLTestPath("/convtest/42", (), {"pk": 42}),
+ URLTestPath("/convtest/42.api", (), {"pk": 42, "format": "api"}),
+ URLTestPath("/convtest/42.asdf", (), {"pk": 42, "format": "asdf"}),
+ URLTestPath("/retest/42", (), {"pk": "42"}),
+ URLTestPath("/retest/42.api", (), {"pk": "42", "format": "api"}),
+ URLTestPath("/retest/42.asdf", (), {"pk": "42", "format": "asdf"}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def _test_default_args(self, urlpatterns):
test_paths = [
- URLTestPath('/test', (), {'foo': 'bar', }),
- URLTestPath('/test.api', (), {'foo': 'bar', 'format': 'api'}),
- URLTestPath('/test.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ URLTestPath("/test", (), {"foo": "bar"}),
+ URLTestPath("/test.api", (), {"foo": "bar", "format": "api"}),
+ URLTestPath("/test.asdf", (), {"foo": "bar", "format": "asdf"}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_default_args(self):
- urlpatterns = [
- url(r'^test$', dummy_view, {'foo': 'bar'}),
- ]
+ urlpatterns = [url(r"^test$", dummy_view, {"foo": "bar"})]
self._test_default_args(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_default_args_django2(self):
- urlpatterns = [
- path('test', dummy_view, {'foo': 'bar'}),
- ]
+ urlpatterns = [path("test", dummy_view, {"foo": "bar"})]
self._test_default_args(urlpatterns)
def _test_included_urls(self, urlpatterns):
test_paths = [
- URLTestPath('/test/path', (), {'foo': 'bar', }),
- URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
- URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
+ URLTestPath("/test/path", (), {"foo": "bar"}),
+ URLTestPath("/test/path.api", (), {"foo": "bar", "format": "api"}),
+ URLTestPath("/test/path.asdf", (), {"foo": "bar", "format": "asdf"}),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def test_included_urls(self):
- nested_patterns = [
- url(r'^path$', dummy_view)
- ]
- urlpatterns = [
- url(r'^test/', include(nested_patterns), {'foo': 'bar'}),
- ]
+ nested_patterns = [url(r"^path$", dummy_view)]
+ urlpatterns = [url(r"^test/", include(nested_patterns), {"foo": "bar"})]
self._test_included_urls(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_included_urls_django2(self):
- nested_patterns = [
- path('path', dummy_view)
- ]
- urlpatterns = [
- path('test/', include(nested_patterns), {'foo': 'bar'}),
- ]
+ nested_patterns = [path("path", dummy_view)]
+ urlpatterns = [path("test/", include(nested_patterns), {"foo": "bar"})]
self._test_included_urls(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_included_urls_django2_mixed(self):
- nested_patterns = [
- path('path', dummy_view)
- ]
- urlpatterns = [
- url('^test/', include(nested_patterns), {'foo': 'bar'}),
- ]
+ nested_patterns = [path("path", dummy_view)]
+ urlpatterns = [url("^test/", include(nested_patterns), {"foo": "bar"})]
self._test_included_urls(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_included_urls_django2_mixed_args(self):
nested_patterns = [
- path('path/', dummy_view),
- url('^url/(?P[0-9]+)$', dummy_view)
+ path("path/", dummy_view),
+ url("^url/(?P[0-9]+)$", dummy_view),
]
urlpatterns = [
- url('^purl/(?P[0-9]+)/', include(nested_patterns), {'foo': 'bar'}),
- path('ppath//', include(nested_patterns), {'foo': 'bar'}),
+ url("^purl/(?P[0-9]+)/", include(nested_patterns), {"foo": "bar"}),
+ path("ppath//", include(nested_patterns), {"foo": "bar"}),
]
test_paths = [
# parent url() nesting child path()
- URLTestPath('/purl/87/path/42', (), {'parent': '87', 'child': 42, 'foo': 'bar', }),
- URLTestPath('/purl/87/path/42.api', (), {'parent': '87', 'child': 42, 'foo': 'bar', 'format': 'api'}),
- URLTestPath('/purl/87/path/42.asdf', (), {'parent': '87', 'child': 42, 'foo': 'bar', 'format': 'asdf'}),
-
+ URLTestPath(
+ "/purl/87/path/42", (), {"parent": "87", "child": 42, "foo": "bar"}
+ ),
+ URLTestPath(
+ "/purl/87/path/42.api",
+ (),
+ {"parent": "87", "child": 42, "foo": "bar", "format": "api"},
+ ),
+ URLTestPath(
+ "/purl/87/path/42.asdf",
+ (),
+ {"parent": "87", "child": 42, "foo": "bar", "format": "asdf"},
+ ),
# parent path() nesting child url()
- URLTestPath('/ppath/87/url/42', (), {'parent': 87, 'child': '42', 'foo': 'bar', }),
- URLTestPath('/ppath/87/url/42.api', (), {'parent': 87, 'child': '42', 'foo': 'bar', 'format': 'api'}),
- URLTestPath('/ppath/87/url/42.asdf', (), {'parent': 87, 'child': '42', 'foo': 'bar', 'format': 'asdf'}),
-
+ URLTestPath(
+ "/ppath/87/url/42", (), {"parent": 87, "child": "42", "foo": "bar"}
+ ),
+ URLTestPath(
+ "/ppath/87/url/42.api",
+ (),
+ {"parent": 87, "child": "42", "foo": "bar", "format": "api"},
+ ),
+ URLTestPath(
+ "/ppath/87/url/42.asdf",
+ (),
+ {"parent": 87, "child": "42", "foo": "bar", "format": "asdf"},
+ ),
# parent path() nesting child path()
- URLTestPath('/ppath/87/path/42', (), {'parent': 87, 'child': 42, 'foo': 'bar', }),
- URLTestPath('/ppath/87/path/42.api', (), {'parent': 87, 'child': 42, 'foo': 'bar', 'format': 'api'}),
- URLTestPath('/ppath/87/path/42.asdf', (), {'parent': 87, 'child': 42, 'foo': 'bar', 'format': 'asdf'}),
-
+ URLTestPath(
+ "/ppath/87/path/42", (), {"parent": 87, "child": 42, "foo": "bar"}
+ ),
+ URLTestPath(
+ "/ppath/87/path/42.api",
+ (),
+ {"parent": 87, "child": 42, "foo": "bar", "format": "api"},
+ ),
+ URLTestPath(
+ "/ppath/87/path/42.asdf",
+ (),
+ {"parent": 87, "child": 42, "foo": "bar", "format": "asdf"},
+ ),
# parent url() nesting child url()
- URLTestPath('/purl/87/url/42', (), {'parent': '87', 'child': '42', 'foo': 'bar', }),
- URLTestPath('/purl/87/url/42.api', (), {'parent': '87', 'child': '42', 'foo': 'bar', 'format': 'api'}),
- URLTestPath('/purl/87/url/42.asdf', (), {'parent': '87', 'child': '42', 'foo': 'bar', 'format': 'asdf'}),
+ URLTestPath(
+ "/purl/87/url/42", (), {"parent": "87", "child": "42", "foo": "bar"}
+ ),
+ URLTestPath(
+ "/purl/87/url/42.api",
+ (),
+ {"parent": "87", "child": "42", "foo": "bar", "format": "api"},
+ ),
+ URLTestPath(
+ "/purl/87/url/42.asdf",
+ (),
+ {"parent": "87", "child": "42", "foo": "bar", "format": "asdf"},
+ ),
]
self._resolve_urlpatterns(urlpatterns, test_paths)
def _test_allowed_formats(self, urlpatterns):
- allowed_formats = ['good', 'ugly']
+ allowed_formats = ["good", "ugly"]
test_paths = [
- (URLTestPath('/test.good/', (), {'format': 'good'}), True),
- (URLTestPath('/test.bad', (), {}), False),
- (URLTestPath('/test.ugly', (), {'format': 'ugly'}), True),
+ (URLTestPath("/test.good/", (), {"format": "good"}), True),
+ (URLTestPath("/test.bad", (), {}), False),
+ (URLTestPath("/test.ugly", (), {"format": "ugly"}), True),
]
self._resolve_urlpatterns(urlpatterns, test_paths, allowed=allowed_formats)
def test_allowed_formats(self):
- urlpatterns = [
- url('^test$', dummy_view),
- ]
+ urlpatterns = [url("^test$", dummy_view)]
self._test_allowed_formats(urlpatterns)
- @unittest.skipUnless(path, 'needs Django 2')
+ @unittest.skipUnless(path, "needs Django 2")
def test_allowed_formats_django2(self):
- urlpatterns = [
- path('test', dummy_view),
- ]
+ urlpatterns = [path("test", dummy_view)]
self._test_allowed_formats(urlpatterns)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 28b06b173..3e3f1c12c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -52,123 +52,125 @@ class ResourceViewSet(ModelViewSet):
def detail_action(self, request, *args, **kwargs):
raise NotImplementedError
- @action(detail=True, name='Custom Name')
+ @action(detail=True, name="Custom Name")
def named_action(self, request, *args, **kwargs):
raise NotImplementedError
- @action(detail=True, suffix='Custom Suffix')
+ @action(detail=True, suffix="Custom Suffix")
def suffixed_action(self, request, *args, **kwargs):
raise NotImplementedError
router = SimpleRouter()
-router.register(r'resources', ResourceViewSet)
+router.register(r"resources", ResourceViewSet)
urlpatterns = [
- url(r'^$', Root.as_view()),
- url(r'^resource/$', ResourceRoot.as_view()),
- url(r'^resource/customname$', CustomNameResourceInstance.as_view()),
- url(r'^resource/(?P[0-9]+)$', ResourceInstance.as_view()),
- url(r'^resource/(?P[0-9]+)/$', NestedResourceRoot.as_view()),
- url(r'^resource/(?P[0-9]+)/(?P[A-Za-z]+)$', NestedResourceInstance.as_view()),
+ url(r"^$", Root.as_view()),
+ url(r"^resource/$", ResourceRoot.as_view()),
+ url(r"^resource/customname$", CustomNameResourceInstance.as_view()),
+ url(r"^resource/(?P[0-9]+)$", ResourceInstance.as_view()),
+ url(r"^resource/(?P[0-9]+)/$", NestedResourceRoot.as_view()),
+ url(
+ r"^resource/(?P[0-9]+)/(?P[A-Za-z]+)$",
+ NestedResourceInstance.as_view(),
+ ),
]
urlpatterns += router.urls
-@override_settings(ROOT_URLCONF='tests.test_utils')
+@override_settings(ROOT_URLCONF="tests.test_utils")
class BreadcrumbTests(TestCase):
"""
Tests the breadcrumb functionality used by the HTML renderer.
"""
+
def test_root_breadcrumbs(self):
- url = '/'
- assert get_breadcrumbs(url) == [('Root', '/')]
+ url = "/"
+ assert get_breadcrumbs(url) == [("Root", "/")]
def test_resource_root_breadcrumbs(self):
- url = '/resource/'
- assert get_breadcrumbs(url) == [
- ('Root', '/'), ('Resource Root', '/resource/')
- ]
+ url = "/resource/"
+ assert get_breadcrumbs(url) == [("Root", "/"), ("Resource Root", "/resource/")]
def test_resource_instance_breadcrumbs(self):
- url = '/resource/123'
+ url = "/resource/123"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123')
+ ("Root", "/"),
+ ("Resource Root", "/resource/"),
+ ("Resource Instance", "/resource/123"),
]
def test_resource_instance_customname_breadcrumbs(self):
- url = '/resource/customname'
+ url = "/resource/customname"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Foo', '/resource/customname')
+ ("Root", "/"),
+ ("Resource Root", "/resource/"),
+ ("Foo", "/resource/customname"),
]
def test_nested_resource_breadcrumbs(self):
- url = '/resource/123/'
+ url = "/resource/123/"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123'),
- ('Nested Resource Root', '/resource/123/')
+ ("Root", "/"),
+ ("Resource Root", "/resource/"),
+ ("Resource Instance", "/resource/123"),
+ ("Nested Resource Root", "/resource/123/"),
]
def test_nested_resource_instance_breadcrumbs(self):
- url = '/resource/123/abc'
+ url = "/resource/123/abc"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource Root', '/resource/'),
- ('Resource Instance', '/resource/123'),
- ('Nested Resource Root', '/resource/123/'),
- ('Nested Resource Instance', '/resource/123/abc')
+ ("Root", "/"),
+ ("Resource Root", "/resource/"),
+ ("Resource Instance", "/resource/123"),
+ ("Nested Resource Root", "/resource/123/"),
+ ("Nested Resource Instance", "/resource/123/abc"),
]
def test_broken_url_breadcrumbs_handled_gracefully(self):
- url = '/foobar'
- assert get_breadcrumbs(url) == [('Root', '/')]
+ url = "/foobar"
+ assert get_breadcrumbs(url) == [("Root", "/")]
def test_modelviewset_resource_instance_breadcrumbs(self):
- url = '/resources/1/'
+ url = "/resources/1/"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource List', '/resources/'),
- ('Resource Instance', '/resources/1/')
+ ("Root", "/"),
+ ("Resource List", "/resources/"),
+ ("Resource Instance", "/resources/1/"),
]
def test_modelviewset_list_action_breadcrumbs(self):
- url = '/resources/list_action/'
+ url = "/resources/list_action/"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource List', '/resources/'),
- ('List action', '/resources/list_action/'),
+ ("Root", "/"),
+ ("Resource List", "/resources/"),
+ ("List action", "/resources/list_action/"),
]
def test_modelviewset_detail_action_breadcrumbs(self):
- url = '/resources/1/detail_action/'
+ url = "/resources/1/detail_action/"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource List', '/resources/'),
- ('Resource Instance', '/resources/1/'),
- ('Detail action', '/resources/1/detail_action/'),
+ ("Root", "/"),
+ ("Resource List", "/resources/"),
+ ("Resource Instance", "/resources/1/"),
+ ("Detail action", "/resources/1/detail_action/"),
]
def test_modelviewset_action_name_kwarg(self):
- url = '/resources/1/named_action/'
+ url = "/resources/1/named_action/"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource List', '/resources/'),
- ('Resource Instance', '/resources/1/'),
- ('Custom Name', '/resources/1/named_action/'),
+ ("Root", "/"),
+ ("Resource List", "/resources/"),
+ ("Resource Instance", "/resources/1/"),
+ ("Custom Name", "/resources/1/named_action/"),
]
def test_modelviewset_action_suffix_kwarg(self):
- url = '/resources/1/suffixed_action/'
+ url = "/resources/1/suffixed_action/"
assert get_breadcrumbs(url) == [
- ('Root', '/'),
- ('Resource List', '/resources/'),
- ('Resource Instance', '/resources/1/'),
- ('Resource Custom Suffix', '/resources/1/suffixed_action/'),
+ ("Root", "/"),
+ ("Resource List", "/resources/"),
+ ("Resource Instance", "/resources/1/"),
+ ("Resource Custom Suffix", "/resources/1/suffixed_action/"),
]
@@ -179,10 +181,10 @@ class JsonFloatTests(TestCase):
def test_dumps(self):
with self.assertRaises(ValueError):
- json.dumps(float('inf'))
+ json.dumps(float("inf"))
with self.assertRaises(ValueError):
- json.dumps(float('nan'))
+ json.dumps(float("nan"))
def test_loads(self):
with self.assertRaises(ValueError):
@@ -203,30 +205,31 @@ class UrlsReplaceQueryParamTests(TestCase):
"""
Tests the replace_query_param functionality.
"""
+
def test_valid_unicode_preserved(self):
# Encoded string: '查询'
- q = '/?q=%E6%9F%A5%E8%AF%A2'
- new_key = 'page'
+ q = "/?q=%E6%9F%A5%E8%AF%A2"
+ new_key = "page"
new_value = 2
- value = '%E6%9F%A5%E8%AF%A2'
+ value = "%E6%9F%A5%E8%AF%A2"
assert new_key in replace_query_param(q, new_key, new_value)
assert value in replace_query_param(q, new_key, new_value)
def test_valid_unicode_replaced(self):
- q = '/?page=1'
- value = '1'
- new_key = 'q'
- new_value = '%E6%9F%A5%E8%AF%A2'
+ q = "/?page=1"
+ value = "1"
+ new_key = "q"
+ new_value = "%E6%9F%A5%E8%AF%A2"
assert new_key in replace_query_param(q, new_key, new_value)
assert value in replace_query_param(q, new_key, new_value)
def test_invalid_unicode(self):
# Encoded string: '��=1'
- q = '/e/?%FF%FE%3C%73%63%72%69%70%74%3E%61%6C%65%72%74%28%33%31%33%29%3C%2F%73%63%72%69%70%74%3E=1'
- key = 'from'
- value = 'login'
+ q = "/e/?%FF%FE%3C%73%63%72%69%70%74%3E%61%6C%65%72%74%28%33%31%33%29%3C%2F%73%63%72%69%70%74%3E=1"
+ key = "from"
+ value = "login"
assert key in replace_query_param(q, key, value)
@@ -235,28 +238,29 @@ class UrlsRemoveQueryParamTests(TestCase):
"""
Tests the remove_query_param functionality.
"""
+
def test_valid_unicode_preserved(self):
- q = '/?q=%E6%9F%A5%E8%AF%A2'
- new_key = 'page'
+ q = "/?q=%E6%9F%A5%E8%AF%A2"
+ new_key = "page"
new_value = 2
- value = '%E6%9F%A5%E8%AF%A2'
+ value = "%E6%9F%A5%E8%AF%A2"
assert new_key in replace_query_param(q, new_key, new_value)
assert value in replace_query_param(q, new_key, new_value)
def test_valid_unicode_removed(self):
- q = '/?page=2345&q=%E6%9F%A5%E8%AF%A2'
- key = 'page'
- value = '2345'
- removed_key = 'q'
+ q = "/?page=2345&q=%E6%9F%A5%E8%AF%A2"
+ key = "page"
+ value = "2345"
+ removed_key = "q"
assert key in remove_query_param(q, removed_key)
assert value in remove_query_param(q, removed_key)
- assert '%' not in remove_query_param(q, removed_key)
+ assert "%" not in remove_query_param(q, removed_key)
def test_invalid_unicode(self):
- q = '/?from=login&page=2&%FF%FE%3C%73%63%72%69%70%74%3E%61%6C%65%72%74%28%33%31%33%29%3C%2F%73%63%72%69%70%74%3E=1'
- key = 'from'
- removed_key = 'page'
+ q = "/?from=login&page=2&%FF%FE%3C%73%63%72%69%70%74%3E%61%6C%65%72%74%28%33%31%33%29%3C%2F%73%63%72%69%70%74%3E=1"
+ key = "from"
+ removed_key = "page"
assert key in remove_query_param(q, removed_key)
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 4132a7b00..aa60e488a 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -10,11 +10,13 @@ from django.utils import six
from rest_framework import generics, serializers, status
from rest_framework.test import APIRequestFactory
+
factory = APIRequestFactory()
# Regression for #666
+
class ValidationModel(models.Model):
blank_validated_field = models.CharField(max_length=255)
@@ -22,8 +24,8 @@ class ValidationModel(models.Model):
class ValidationModelSerializer(serializers.ModelSerializer):
class Meta:
model = ValidationModel
- fields = ('blank_validated_field',)
- read_only_fields = ('blank_validated_field',)
+ fields = ("blank_validated_field",)
+ read_only_fields = ("blank_validated_field",)
class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
@@ -33,21 +35,22 @@ class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
# Regression for #653
+
class ShouldValidateModel(models.Model):
should_validate_field = models.CharField(max_length=255)
class ShouldValidateModelSerializer(serializers.ModelSerializer):
- renamed = serializers.CharField(source='should_validate_field', required=False)
+ renamed = serializers.CharField(source="should_validate_field", required=False)
def validate_renamed(self, value):
if len(value) < 3:
- raise serializers.ValidationError('Minimum 3 characters.')
+ raise serializers.ValidationError("Minimum 3 characters.")
return value
class Meta:
model = ShouldValidateModel
- fields = ('renamed',)
+ fields = ("renamed",)
class TestNestedValidationError(TestCase):
@@ -55,17 +58,9 @@ class TestNestedValidationError(TestCase):
"""
Ensure nested validation error detail is rendered correctly.
"""
- e = serializers.ValidationError({
- 'nested': {
- 'field': ['error'],
- }
- })
+ e = serializers.ValidationError({"nested": {"field": ["error"]}})
- assert serializers.as_serializer_error(e) == {
- 'nested': {
- 'field': ['error'],
- }
- }
+ assert serializers.as_serializer_error(e) == {"nested": {"field": ["error"]}}
class TestPreSaveValidationExclusionsSerializer(TestCase):
@@ -75,20 +70,20 @@ class TestPreSaveValidationExclusionsSerializer(TestCase):
"""
# We've set `required=False` on the serializer, but the model
# does not have `blank=True`, so this serializer should not validate.
- serializer = ShouldValidateModelSerializer(data={'renamed': ''})
+ serializer = ShouldValidateModelSerializer(data={"renamed": ""})
assert serializer.is_valid() is False
- assert 'renamed' in serializer.errors
- assert 'should_validate_field' not in serializer.errors
+ assert "renamed" in serializer.errors
+ assert "should_validate_field" not in serializer.errors
class TestCustomValidationMethods(TestCase):
def test_custom_validation_method_is_executed(self):
- serializer = ShouldValidateModelSerializer(data={'renamed': 'fo'})
+ serializer = ShouldValidateModelSerializer(data={"renamed": "fo"})
assert not serializer.is_valid()
- assert 'renamed' in serializer.errors
+ assert "renamed" in serializer.errors
def test_custom_validation_method_passing(self):
- serializer = ShouldValidateModelSerializer(data={'renamed': 'foo'})
+ serializer = ShouldValidateModelSerializer(data={"renamed": "foo"})
assert serializer.is_valid()
@@ -107,18 +102,21 @@ class TestAvoidValidation(TestCase):
If serializer was initialized with invalid data (None or non dict-like), it
should avoid validation layer (validate_ and validate methods)
"""
+
def test_serializer_errors_has_only_invalid_data_error(self):
- serializer = ValidationSerializer(data='invalid data')
+ serializer = ValidationSerializer(data="invalid data")
assert not serializer.is_valid()
assert serializer.errors == {
- 'non_field_errors': [
- 'Invalid data. Expected a dictionary, but got %s.' % six.text_type.__name__
+ "non_field_errors": [
+ "Invalid data. Expected a dictionary, but got %s."
+ % six.text_type.__name__
]
}
# regression tests for issue: 1493
+
class ValidationMaxValueValidatorModel(models.Model):
number_value = models.PositiveIntegerField(validators=[MaxValueValidator(100)])
@@ -126,7 +124,7 @@ class ValidationMaxValueValidatorModel(models.Model):
class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer):
class Meta:
model = ValidationMaxValueValidatorModel
- fields = '__all__'
+ fields = "__all__"
class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
@@ -135,64 +133,58 @@ class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
class TestMaxValueValidatorValidation(TestCase):
-
def test_max_value_validation_serializer_success(self):
- serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 99})
+ serializer = ValidationMaxValueValidatorModelSerializer(
+ data={"number_value": 99}
+ )
assert serializer.is_valid()
def test_max_value_validation_serializer_fails(self):
- serializer = ValidationMaxValueValidatorModelSerializer(data={'number_value': 101})
+ serializer = ValidationMaxValueValidatorModelSerializer(
+ data={"number_value": 101}
+ )
assert not serializer.is_valid()
assert serializer.errors == {
- 'number_value': [
- 'Ensure this value is less than or equal to 100.'
- ]
+ "number_value": ["Ensure this value is less than or equal to 100."]
}
def test_max_value_validation_success(self):
obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
- request = factory.patch('/{0}'.format(obj.pk), {'number_value': 98}, format='json')
+ request = factory.patch(
+ "/{0}".format(obj.pk), {"number_value": 98}, format="json"
+ )
view = UpdateMaxValueValidationModel().as_view()
response = view(request, pk=obj.pk).render()
assert response.status_code == status.HTTP_200_OK
def test_max_value_validation_fail(self):
obj = ValidationMaxValueValidatorModel.objects.create(number_value=100)
- request = factory.patch('/{0}'.format(obj.pk), {'number_value': 101}, format='json')
+ request = factory.patch(
+ "/{0}".format(obj.pk), {"number_value": 101}, format="json"
+ )
view = UpdateMaxValueValidationModel().as_view()
response = view(request, pk=obj.pk).render()
- assert response.content == b'{"number_value":["Ensure this value is less than or equal to 100."]}'
+ assert (
+ response.content
+ == b'{"number_value":["Ensure this value is less than or equal to 100."]}'
+ )
assert response.status_code == status.HTTP_400_BAD_REQUEST
# regression tests for issue: 1533
+
class TestChoiceFieldChoicesValidate(TestCase):
- CHOICES = [
- (0, 'Small'),
- (1, 'Medium'),
- (2, 'Large'),
- ]
+ CHOICES = [(0, "Small"), (1, "Medium"), (2, "Large")]
SINGLE_CHOICES = [0, 1, 2]
CHOICES_NESTED = [
- ('Category', (
- (1, 'First'),
- (2, 'Second'),
- (3, 'Third'),
- )),
- (4, 'Fourth'),
+ ("Category", ((1, "First"), (2, "Second"), (3, "Third"))),
+ (4, "Fourth"),
]
- MIXED_CHOICES = [
- ('Category', (
- (1, 'First'),
- (2, 'Second'),
- )),
- 3,
- (4, 'Fourth'),
- ]
+ MIXED_CHOICES = [("Category", ((1, "First"), (2, "Second"))), 3, (4, "Fourth")]
def test_choices(self):
"""
@@ -241,8 +233,12 @@ class TestChoiceFieldChoicesValidate(TestCase):
class RegexSerializer(serializers.Serializer):
pin = serializers.CharField(
- validators=[RegexValidator(regex=re.compile('^[0-9]{4,6}$'),
- message='A PIN is 4-6 digits')])
+ validators=[
+ RegexValidator(
+ regex=re.compile("^[0-9]{4,6}$"), message="A PIN is 4-6 digits"
+ )
+ ]
+ )
expected_repr = """
diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py
index 562fe37e6..f68bfcbd0 100644
--- a/tests/test_validation_error.py
+++ b/tests/test_validation_error.py
@@ -7,6 +7,7 @@ from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
+
factory = APIRequestFactory()
@@ -20,7 +21,7 @@ class ErrorView(APIView):
ExampleSerializer(data={}).is_valid(raise_exception=True)
-@api_view(['GET'])
+@api_view(["GET"])
def error_view(request):
ExampleSerializer(data={}).is_valid(raise_exception=True)
@@ -36,14 +37,8 @@ class TestValidationErrorWithFullDetails(TestCase):
api_settings.EXCEPTION_HANDLER = exception_handler
self.expected_response_data = {
- 'char': [{
- 'message': 'This field is required.',
- 'code': 'required',
- }],
- 'integer': [{
- 'message': 'This field is required.',
- 'code': 'required'
- }],
+ "char": [{"message": "This field is required.", "code": "required"}],
+ "integer": [{"message": "This field is required.", "code": "required"}],
}
def tearDown(self):
@@ -52,7 +47,7 @@ class TestValidationErrorWithFullDetails(TestCase):
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
- request = factory.get('/', content_type='application/json')
+ request = factory.get("/", content_type="application/json")
response = view(request)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.data == self.expected_response_data
@@ -60,7 +55,7 @@ class TestValidationErrorWithFullDetails(TestCase):
def test_function_based_view_exception_handler(self):
view = error_view
- request = factory.get('/', content_type='application/json')
+ request = factory.get("/", content_type="application/json")
response = view(request)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.data == self.expected_response_data
@@ -76,10 +71,7 @@ class TestValidationErrorWithCodes(TestCase):
api_settings.EXCEPTION_HANDLER = exception_handler
- self.expected_response_data = {
- 'char': ['required'],
- 'integer': ['required'],
- }
+ self.expected_response_data = {"char": ["required"], "integer": ["required"]}
def tearDown(self):
api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER
@@ -87,7 +79,7 @@ class TestValidationErrorWithCodes(TestCase):
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
- request = factory.get('/', content_type='application/json')
+ request = factory.get("/", content_type="application/json")
response = view(request)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.data == self.expected_response_data
@@ -95,7 +87,7 @@ class TestValidationErrorWithCodes(TestCase):
def test_function_based_view_exception_handler(self):
view = error_view
- request = factory.get('/', content_type='application/json')
+ request = factory.get("/", content_type="application/json")
response = view(request)
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.data == self.expected_response_data
diff --git a/tests/test_validators.py b/tests/test_validators.py
index 4bbddb64b..08962da67 100644
--- a/tests/test_validators.py
+++ b/tests/test_validators.py
@@ -7,18 +7,21 @@ from django.test import TestCase
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.validators import (
- BaseUniqueForValidator, UniqueTogetherValidator, UniqueValidator,
- qs_exists
+ BaseUniqueForValidator,
+ UniqueTogetherValidator,
+ UniqueValidator,
+ qs_exists,
)
def dedent(blocktext):
- return '\n'.join([line[12:] for line in blocktext.splitlines()[1:-1]])
+ return "\n".join([line[12:] for line in blocktext.splitlines()[1:-1]])
# Tests for `UniqueValidator`
# ---------------------------
+
class UniquenessModel(models.Model):
username = models.CharField(unique=True, max_length=100)
@@ -26,7 +29,7 @@ class UniquenessModel(models.Model):
class UniquenessSerializer(serializers.ModelSerializer):
class Meta:
model = UniquenessModel
- fields = '__all__'
+ fields = "__all__"
class RelatedModel(models.Model):
@@ -35,12 +38,16 @@ class RelatedModel(models.Model):
class RelatedModelSerializer(serializers.ModelSerializer):
- username = serializers.CharField(source='user.username',
- validators=[UniqueValidator(queryset=UniquenessModel.objects.all(), lookup='iexact')]) # NOQA
+ username = serializers.CharField(
+ source="user.username",
+ validators=[
+ UniqueValidator(queryset=UniquenessModel.objects.all(), lookup="iexact")
+ ],
+ ) # NOQA
class Meta:
model = RelatedModel
- fields = ('username', 'email')
+ fields = ("username", "email")
class AnotherUniquenessModel(models.Model):
@@ -50,7 +57,7 @@ class AnotherUniquenessModel(models.Model):
class AnotherUniquenessSerializer(serializers.ModelSerializer):
class Meta:
model = AnotherUniquenessModel
- fields = '__all__'
+ fields = "__all__"
class IntegerFieldModel(models.Model):
@@ -62,72 +69,79 @@ class UniquenessIntegerSerializer(serializers.Serializer):
# This allows us to ensure that `ValueError`, `TypeError` or `DataError` etc
# raised by a uniqueness check does not trigger a deceptive "this field is not unique"
# validation failure.
- integer = serializers.CharField(validators=[UniqueValidator(queryset=IntegerFieldModel.objects.all())])
+ integer = serializers.CharField(
+ validators=[UniqueValidator(queryset=IntegerFieldModel.objects.all())]
+ )
class TestUniquenessValidation(TestCase):
def setUp(self):
- self.instance = UniquenessModel.objects.create(username='existing')
+ self.instance = UniquenessModel.objects.create(username="existing")
def test_repr(self):
serializer = UniquenessSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
UniquenessSerializer():
id = IntegerField(label='ID', read_only=True)
username = CharField(max_length=100, validators=[])
- """)
+ """
+ )
assert repr(serializer) == expected
def test_is_not_unique(self):
- data = {'username': 'existing'}
+ data = {"username": "existing"}
serializer = UniquenessSerializer(data=data)
assert not serializer.is_valid()
- assert serializer.errors == {'username': ['uniqueness model with this username already exists.']}
+ assert serializer.errors == {
+ "username": ["uniqueness model with this username already exists."]
+ }
def test_is_unique(self):
- data = {'username': 'other'}
+ data = {"username": "other"}
serializer = UniquenessSerializer(data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {'username': 'other'}
+ assert serializer.validated_data == {"username": "other"}
def test_updated_instance_excluded(self):
- data = {'username': 'existing'}
+ data = {"username": "existing"}
serializer = UniquenessSerializer(self.instance, data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {'username': 'existing'}
+ assert serializer.validated_data == {"username": "existing"}
def test_doesnt_pollute_model(self):
- instance = AnotherUniquenessModel.objects.create(code='100')
+ instance = AnotherUniquenessModel.objects.create(code="100")
serializer = AnotherUniquenessSerializer(instance)
- assert AnotherUniquenessModel._meta.get_field('code').validators == []
+ assert AnotherUniquenessModel._meta.get_field("code").validators == []
# Accessing data shouldn't effect validators on the model
serializer.data
- assert AnotherUniquenessModel._meta.get_field('code').validators == []
+ assert AnotherUniquenessModel._meta.get_field("code").validators == []
def test_related_model_is_unique(self):
- data = {'username': 'Existing', 'email': 'new-email@example.com'}
+ data = {"username": "Existing", "email": "new-email@example.com"}
rs = RelatedModelSerializer(data=data)
assert not rs.is_valid()
- assert rs.errors == {'username': ['This field must be unique.']}
- data = {'username': 'new-username', 'email': 'new-email@example.com'}
+ assert rs.errors == {"username": ["This field must be unique."]}
+ data = {"username": "new-username", "email": "new-email@example.com"}
rs = RelatedModelSerializer(data=data)
assert rs.is_valid()
def test_value_error_treated_as_not_unique(self):
- serializer = UniquenessIntegerSerializer(data={'integer': 'abc'})
+ serializer = UniquenessIntegerSerializer(data={"integer": "abc"})
assert serializer.is_valid()
# Tests for `UniqueTogetherValidator`
# -----------------------------------
+
class UniquenessTogetherModel(models.Model):
race_name = models.CharField(max_length=100)
position = models.IntegerField()
class Meta:
- unique_together = ('race_name', 'position')
+ unique_together = ("race_name", "position")
class NullUniquenessTogetherModel(models.Model):
@@ -143,63 +157,59 @@ class NullUniquenessTogetherModel(models.Model):
there could be many non-finishers in a race, but all non-NULL
values *should* be unique against the given `race_name`.
"""
+
date_of_birth = models.DateField(null=True) # Not part of the uniqueness constraint
race_name = models.CharField(max_length=100)
position = models.IntegerField(null=True)
class Meta:
- unique_together = ('race_name', 'position')
+ unique_together = ("race_name", "position")
class UniquenessTogetherSerializer(serializers.ModelSerializer):
class Meta:
model = UniquenessTogetherModel
- fields = '__all__'
+ fields = "__all__"
class NullUniquenessTogetherSerializer(serializers.ModelSerializer):
class Meta:
model = NullUniquenessTogetherModel
- fields = '__all__'
+ fields = "__all__"
class TestUniquenessTogetherValidation(TestCase):
def setUp(self):
self.instance = UniquenessTogetherModel.objects.create(
- race_name='example',
- position=1
- )
- UniquenessTogetherModel.objects.create(
- race_name='example',
- position=2
- )
- UniquenessTogetherModel.objects.create(
- race_name='other',
- position=1
+ race_name="example", position=1
)
+ UniquenessTogetherModel.objects.create(race_name="example", position=2)
+ UniquenessTogetherModel.objects.create(race_name="other", position=1)
def test_repr(self):
serializer = UniquenessTogetherSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
UniquenessTogetherSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100, required=True)
position = IntegerField(required=True)
class Meta:
validators = []
- """)
+ """
+ )
assert repr(serializer) == expected
def test_is_not_unique_together(self):
"""
Failing unique together validation should result in non field errors.
"""
- data = {'race_name': 'example', 'position': 2}
+ data = {"race_name": "example", "position": 2}
serializer = UniquenessTogetherSerializer(data=data)
assert not serializer.is_valid()
assert serializer.errors == {
- 'non_field_errors': [
- 'The fields race_name, position must make a unique set.'
+ "non_field_errors": [
+ "The fields race_name, position must make a unique set."
]
}
@@ -208,53 +218,49 @@ class TestUniquenessTogetherValidation(TestCase):
In a unique together validation, one field may be non-unique
so long as the set as a whole is unique.
"""
- data = {'race_name': 'other', 'position': 2}
+ data = {"race_name": "other", "position": 2}
serializer = UniquenessTogetherSerializer(data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {
- 'race_name': 'other',
- 'position': 2
- }
+ assert serializer.validated_data == {"race_name": "other", "position": 2}
def test_updated_instance_excluded_from_unique_together(self):
"""
When performing an update, the existing instance does not count
as a match against uniqueness.
"""
- data = {'race_name': 'example', 'position': 1}
+ data = {"race_name": "example", "position": 1}
serializer = UniquenessTogetherSerializer(self.instance, data=data)
assert serializer.is_valid()
- assert serializer.validated_data == {
- 'race_name': 'example',
- 'position': 1
- }
+ assert serializer.validated_data == {"race_name": "example", "position": 1}
def test_unique_together_is_required(self):
"""
In a unique together validation, all fields are required.
"""
- data = {'position': 2}
+ data = {"position": 2}
serializer = UniquenessTogetherSerializer(data=data, partial=True)
assert not serializer.is_valid()
- assert serializer.errors == {
- 'race_name': ['This field is required.']
- }
+ assert serializer.errors == {"race_name": ["This field is required."]}
def test_ignore_excluded_fields(self):
"""
When model fields are not included in a serializer, then uniqueness
validators should not be added for that field.
"""
+
class ExcludedFieldSerializer(serializers.ModelSerializer):
class Meta:
model = UniquenessTogetherModel
- fields = ('id', 'race_name',)
+ fields = ("id", "race_name")
+
serializer = ExcludedFieldSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
ExcludedFieldSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100)
- """)
+ """
+ )
assert repr(serializer) == expected
def test_ignore_read_only_fields(self):
@@ -262,42 +268,48 @@ class TestUniquenessTogetherValidation(TestCase):
When serializer fields are read only, then uniqueness
validators should not be added for that field.
"""
+
class ReadOnlyFieldSerializer(serializers.ModelSerializer):
class Meta:
model = UniquenessTogetherModel
- fields = ('id', 'race_name', 'position')
- read_only_fields = ('race_name',)
+ fields = ("id", "race_name", "position")
+ read_only_fields = ("race_name",)
serializer = ReadOnlyFieldSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
ReadOnlyFieldSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(read_only=True)
position = IntegerField(required=True)
- """)
+ """
+ )
assert repr(serializer) == expected
def test_read_only_fields_with_default(self):
"""
Special case of read_only + default DOES validate unique_together.
"""
+
class ReadOnlyFieldWithDefaultSerializer(serializers.ModelSerializer):
- race_name = serializers.CharField(max_length=100, read_only=True, default='example')
+ race_name = serializers.CharField(
+ max_length=100, read_only=True, default="example"
+ )
class Meta:
model = UniquenessTogetherModel
- fields = ('id', 'race_name', 'position')
+ fields = ("id", "race_name", "position")
- data = {'position': 2}
+ data = {"position": 2}
serializer = ReadOnlyFieldWithDefaultSerializer(data=data)
assert len(serializer.validators) == 1
assert isinstance(serializer.validators[0], UniqueTogetherValidator)
- assert serializer.validators[0].fields == ('race_name', 'position')
+ assert serializer.validators[0].fields == ("race_name", "position")
assert not serializer.is_valid()
assert serializer.errors == {
- 'non_field_errors': [
- 'The fields race_name, position must make a unique set.'
+ "non_field_errors": [
+ "The fields race_name, position must make a unique set."
]
}
@@ -305,19 +317,22 @@ class TestUniquenessTogetherValidation(TestCase):
"""
Ensure validators can be explicitly removed..
"""
+
class NoValidatorsSerializer(serializers.ModelSerializer):
class Meta:
model = UniquenessTogetherModel
- fields = ('id', 'race_name', 'position')
+ fields = ("id", "race_name", "position")
validators = []
serializer = NoValidatorsSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
NoValidatorsSerializer():
id = IntegerField(label='ID', read_only=True)
race_name = CharField(max_length=100)
position = IntegerField()
- """)
+ """
+ )
assert repr(serializer) == expected
def test_ignore_validation_for_null_fields(self):
@@ -325,13 +340,13 @@ class TestUniquenessTogetherValidation(TestCase):
# constraint cause the instance to ignore uniqueness validation.
NullUniquenessTogetherModel.objects.create(
date_of_birth=datetime.date(2000, 1, 1),
- race_name='Paris Marathon',
- position=None
+ race_name="Paris Marathon",
+ position=None,
)
data = {
- 'date': datetime.date(2000, 1, 1),
- 'race_name': 'Paris Marathon',
- 'position': None
+ "date": datetime.date(2000, 1, 1),
+ "race_name": "Paris Marathon",
+ "position": None,
}
serializer = NullUniquenessTogetherSerializer(data=data)
assert serializer.is_valid()
@@ -341,10 +356,10 @@ class TestUniquenessTogetherValidation(TestCase):
# do not cause the instance to skip validation.
NullUniquenessTogetherModel.objects.create(
date_of_birth=datetime.date(2000, 1, 1),
- race_name='Paris Marathon',
- position=1
+ race_name="Paris Marathon",
+ position=1,
)
- data = {'date': None, 'race_name': 'Paris Marathon', 'position': 1}
+ data = {"date": None, "race_name": "Paris Marathon", "position": 1}
serializer = NullUniquenessTogetherSerializer(data=data)
assert not serializer.is_valid()
@@ -353,73 +368,75 @@ class TestUniquenessTogetherValidation(TestCase):
filter_queryset should add value from existing instance attribute
if it is not provided in attributes dict
"""
+
class MockQueryset(object):
def filter(self, **kwargs):
self.called_with = kwargs
- data = {'race_name': 'bar'}
+ data = {"race_name": "bar"}
queryset = MockQueryset()
- validator = UniqueTogetherValidator(queryset, fields=('race_name',
- 'position'))
+ validator = UniqueTogetherValidator(queryset, fields=("race_name", "position"))
validator.instance = self.instance
validator.filter_queryset(attrs=data, queryset=queryset)
- assert queryset.called_with == {'race_name': 'bar', 'position': 1}
+ assert queryset.called_with == {"race_name": "bar", "position": 1}
# Tests for `UniqueForDateValidator`
# ----------------------------------
+
class UniqueForDateModel(models.Model):
- slug = models.CharField(max_length=100, unique_for_date='published')
+ slug = models.CharField(max_length=100, unique_for_date="published")
published = models.DateField()
class UniqueForDateSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueForDateModel
- fields = '__all__'
+ fields = "__all__"
class TestUniquenessForDateValidation(TestCase):
def setUp(self):
self.instance = UniqueForDateModel.objects.create(
- slug='existing',
- published='2000-01-01'
+ slug="existing", published="2000-01-01"
)
def test_repr(self):
serializer = UniqueForDateSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
UniqueForDateSerializer():
id = IntegerField(label='ID', read_only=True)
slug = CharField(max_length=100)
published = DateField(required=True)
class Meta:
validators = []
- """)
+ """
+ )
assert repr(serializer) == expected
def test_is_not_unique_for_date(self):
"""
Failing unique for date validation should result in field error.
"""
- data = {'slug': 'existing', 'published': '2000-01-01'}
+ data = {"slug": "existing", "published": "2000-01-01"}
serializer = UniqueForDateSerializer(data=data)
assert not serializer.is_valid()
assert serializer.errors == {
- 'slug': ['This field must be unique for the "published" date.']
+ "slug": ['This field must be unique for the "published" date.']
}
def test_is_unique_for_date(self):
"""
Passing unique for date validation.
"""
- data = {'slug': 'existing', 'published': '2000-01-02'}
+ data = {"slug": "existing", "published": "2000-01-02"}
serializer = UniqueForDateSerializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {
- 'slug': 'existing',
- 'published': datetime.date(2000, 1, 2)
+ "slug": "existing",
+ "published": datetime.date(2000, 1, 2),
}
def test_updated_instance_excluded_from_unique_for_date(self):
@@ -427,95 +444,95 @@ class TestUniquenessForDateValidation(TestCase):
When performing an update, the existing instance does not count
as a match against unique_for_date.
"""
- data = {'slug': 'existing', 'published': '2000-01-01'}
+ data = {"slug": "existing", "published": "2000-01-01"}
serializer = UniqueForDateSerializer(instance=self.instance, data=data)
assert serializer.is_valid()
assert serializer.validated_data == {
- 'slug': 'existing',
- 'published': datetime.date(2000, 1, 1)
+ "slug": "existing",
+ "published": datetime.date(2000, 1, 1),
}
+
# Tests for `UniqueForMonthValidator`
# ----------------------------------
class UniqueForMonthModel(models.Model):
- slug = models.CharField(max_length=100, unique_for_month='published')
+ slug = models.CharField(max_length=100, unique_for_month="published")
published = models.DateField()
class UniqueForMonthSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueForMonthModel
- fields = '__all__'
+ fields = "__all__"
class UniqueForMonthTests(TestCase):
-
def setUp(self):
self.instance = UniqueForMonthModel.objects.create(
- slug='existing', published='2017-01-01'
+ slug="existing", published="2017-01-01"
)
def test_not_unique_for_month(self):
- data = {'slug': 'existing', 'published': '2017-01-01'}
+ data = {"slug": "existing", "published": "2017-01-01"}
serializer = UniqueForMonthSerializer(data=data)
assert not serializer.is_valid()
assert serializer.errors == {
- 'slug': ['This field must be unique for the "published" month.']
+ "slug": ['This field must be unique for the "published" month.']
}
def test_unique_for_month(self):
- data = {'slug': 'existing', 'published': '2017-02-01'}
+ data = {"slug": "existing", "published": "2017-02-01"}
serializer = UniqueForMonthSerializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {
- 'slug': 'existing',
- 'published': datetime.date(2017, 2, 1)
+ "slug": "existing",
+ "published": datetime.date(2017, 2, 1),
}
+
# Tests for `UniqueForYearValidator`
# ----------------------------------
class UniqueForYearModel(models.Model):
- slug = models.CharField(max_length=100, unique_for_year='published')
+ slug = models.CharField(max_length=100, unique_for_year="published")
published = models.DateField()
class UniqueForYearSerializer(serializers.ModelSerializer):
class Meta:
model = UniqueForYearModel
- fields = '__all__'
+ fields = "__all__"
class UniqueForYearTests(TestCase):
-
def setUp(self):
self.instance = UniqueForYearModel.objects.create(
- slug='existing', published='2017-01-01'
+ slug="existing", published="2017-01-01"
)
def test_not_unique_for_year(self):
- data = {'slug': 'existing', 'published': '2017-01-01'}
+ data = {"slug": "existing", "published": "2017-01-01"}
serializer = UniqueForYearSerializer(data=data)
assert not serializer.is_valid()
assert serializer.errors == {
- 'slug': ['This field must be unique for the "published" year.']
+ "slug": ['This field must be unique for the "published" year.']
}
def test_unique_for_year(self):
- data = {'slug': 'existing', 'published': '2018-01-01'}
+ data = {"slug": "existing", "published": "2018-01-01"}
serializer = UniqueForYearSerializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == {
- 'slug': 'existing',
- 'published': datetime.date(2018, 1, 1)
+ "slug": "existing",
+ "published": datetime.date(2018, 1, 1),
}
class HiddenFieldUniqueForDateModel(models.Model):
- slug = models.CharField(max_length=100, unique_for_date='published')
+ slug = models.CharField(max_length=100, unique_for_date="published")
published = models.DateTimeField(auto_now_add=True)
@@ -524,66 +541,74 @@ class TestHiddenFieldUniquenessForDateValidation(TestCase):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = HiddenFieldUniqueForDateModel
- fields = ('id', 'slug')
+ fields = ("id", "slug")
serializer = TestSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
slug = CharField(max_length=100)
published = HiddenField(default=CreateOnlyDefault())
class Meta:
validators = []
- """)
+ """
+ )
assert repr(serializer) == expected
def test_repr_date_field_included(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = HiddenFieldUniqueForDateModel
- fields = ('id', 'slug', 'published')
+ fields = ("id", "slug", "published")
serializer = TestSerializer()
- expected = dedent("""
+ expected = dedent(
+ """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
slug = CharField(max_length=100)
published = DateTimeField(default=CreateOnlyDefault(), read_only=True)
class Meta:
validators = []
- """)
+ """
+ )
assert repr(serializer) == expected
class ValidatorsTests(TestCase):
-
def test_qs_exists_handles_type_error(self):
class TypeErrorQueryset(object):
def exists(self):
raise TypeError
+
assert qs_exists(TypeErrorQueryset()) is False
def test_qs_exists_handles_value_error(self):
class ValueErrorQueryset(object):
def exists(self):
raise ValueError
+
assert qs_exists(ValueErrorQueryset()) is False
def test_qs_exists_handles_data_error(self):
class DataErrorQueryset(object):
def exists(self):
raise DataError
+
assert qs_exists(DataErrorQueryset()) is False
def test_validator_raises_error_if_not_all_fields_are_provided(self):
- validator = BaseUniqueForValidator(queryset=object(), field='foo',
- date_field='bar')
- attrs = {'foo': 'baz'}
+ validator = BaseUniqueForValidator(
+ queryset=object(), field="foo", date_field="bar"
+ )
+ attrs = {"foo": "baz"}
with pytest.raises(ValidationError):
validator.enforce_required_fields(attrs)
def test_validator_raises_error_when_abstract_method_called(self):
- validator = BaseUniqueForValidator(queryset=object(), field='foo',
- date_field='bar')
+ validator = BaseUniqueForValidator(
+ queryset=object(), field="foo", date_field="bar"
+ )
with pytest.raises(NotImplementedError):
validator.filter_queryset(attrs=None, queryset=None)
diff --git a/tests/test_versioning.py b/tests/test_versioning.py
index 7e650e275..b78bcaa66 100644
--- a/tests/test_versioning.py
+++ b/tests/test_versioning.py
@@ -7,49 +7,47 @@ from rest_framework.decorators import APIView
from rest_framework.relations import PKOnlyObject
from rest_framework.response import Response
from rest_framework.reverse import reverse
-from rest_framework.test import (
- APIRequestFactory, APITestCase, URLPatternsTestCase
-)
+from rest_framework.test import APIRequestFactory, APITestCase, URLPatternsTestCase
from rest_framework.versioning import NamespaceVersioning
class RequestVersionView(APIView):
def get(self, request, *args, **kwargs):
- return Response({'version': request.version})
+ return Response({"version": request.version})
class ReverseView(APIView):
def get(self, request, *args, **kwargs):
- return Response({'url': reverse('another', request=request)})
+ return Response({"url": reverse("another", request=request)})
class AllowedVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class()
- scheme.allowed_versions = ('v1', 'v2')
+ scheme.allowed_versions = ("v1", "v2")
return (scheme.determine_version(request, *args, **kwargs), scheme)
class AllowedAndDefaultVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class()
- scheme.allowed_versions = ('v1', 'v2')
- scheme.default_version = 'v2'
+ scheme.allowed_versions = ("v1", "v2")
+ scheme.default_version = "v2"
return (scheme.determine_version(request, *args, **kwargs), scheme)
class AllowedWithNoneVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class()
- scheme.allowed_versions = ('v1', 'v2', None)
+ scheme.allowed_versions = ("v1", "v2", None)
return (scheme.determine_version(request, *args, **kwargs), scheme)
class AllowedWithNoneAndDefaultVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class()
- scheme.allowed_versions = ('v1', 'v2', None)
- scheme.default_version = 'v2'
+ scheme.allowed_versions = ("v1", "v2", None)
+ scheme.default_version = "v2"
return (scheme.determine_version(request, *args, **kwargs), scheme)
@@ -68,151 +66,153 @@ class TestRequestVersion:
def test_unversioned(self):
view = RequestVersionView.as_view()
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'version': None}
+ assert response.data == {"version": None}
def test_query_param_versioning(self):
scheme = versioning.QueryParameterVersioning
view = RequestVersionView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/?version=1.2.3')
+ request = factory.get("/endpoint/?version=1.2.3")
response = view(request)
- assert response.data == {'version': '1.2.3'}
+ assert response.data == {"version": "1.2.3"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'version': None}
+ assert response.data == {"version": None}
- @override_settings(ALLOWED_HOSTS=['*'])
+ @override_settings(ALLOWED_HOSTS=["*"])
def test_host_name_versioning(self):
scheme = versioning.HostNameVersioning
view = RequestVersionView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_HOST='v1.example.org')
+ request = factory.get("/endpoint/", HTTP_HOST="v1.example.org")
response = view(request)
- assert response.data == {'version': 'v1'}
+ assert response.data == {"version": "v1"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'version': None}
+ assert response.data == {"version": None}
def test_accept_header_versioning(self):
scheme = versioning.AcceptHeaderVersioning
view = RequestVersionView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=1.2.3')
+ request = factory.get(
+ "/endpoint/", HTTP_ACCEPT="application/json; version=1.2.3"
+ )
response = view(request)
- assert response.data == {'version': '1.2.3'}
+ assert response.data == {"version": "1.2.3"}
- request = factory.get('/endpoint/', HTTP_ACCEPT='*/*; version=1.2.3')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="*/*; version=1.2.3")
response = view(request)
- assert response.data == {'version': '1.2.3'}
+ assert response.data == {"version": "1.2.3"}
- request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="application/json")
response = view(request)
- assert response.data == {'version': None}
+ assert response.data == {"version": None}
def test_url_path_versioning(self):
scheme = versioning.URLPathVersioning
view = RequestVersionView.as_view(versioning_class=scheme)
- request = factory.get('/1.2.3/endpoint/')
- response = view(request, version='1.2.3')
- assert response.data == {'version': '1.2.3'}
+ request = factory.get("/1.2.3/endpoint/")
+ response = view(request, version="1.2.3")
+ assert response.data == {"version": "1.2.3"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'version': None}
+ assert response.data == {"version": None}
def test_namespace_versioning(self):
class FakeResolverMatch:
- namespace = 'v1'
+ namespace = "v1"
scheme = versioning.NamespaceVersioning
view = RequestVersionView.as_view(versioning_class=scheme)
- request = factory.get('/v1/endpoint/')
+ request = factory.get("/v1/endpoint/")
request.resolver_match = FakeResolverMatch
- response = view(request, version='v1')
- assert response.data == {'version': 'v1'}
+ response = view(request, version="v1")
+ assert response.data == {"version": "v1"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'version': None}
+ assert response.data == {"version": None}
class TestURLReversing(URLPatternsTestCase, APITestCase):
included = [
- url(r'^namespaced/$', dummy_view, name='another'),
- url(r'^example/(?P\d+)/$', dummy_pk_view, name='example-detail')
+ url(r"^namespaced/$", dummy_view, name="another"),
+ url(r"^example/(?P\d+)/$", dummy_pk_view, name="example-detail"),
]
urlpatterns = [
- url(r'^v1/', include((included, 'v1'), namespace='v1')),
- url(r'^another/$', dummy_view, name='another'),
- url(r'^(?P[v1|v2]+)/another/$', dummy_view, name='another'),
+ url(r"^v1/", include((included, "v1"), namespace="v1")),
+ url(r"^another/$", dummy_view, name="another"),
+ url(r"^(?P[v1|v2]+)/another/$", dummy_view, name="another"),
]
def test_reverse_unversioned(self):
view = ReverseView.as_view()
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'url': 'http://testserver/another/'}
+ assert response.data == {"url": "http://testserver/another/"}
def test_reverse_query_param_versioning(self):
scheme = versioning.QueryParameterVersioning
view = ReverseView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/?version=v1')
+ request = factory.get("/endpoint/?version=v1")
response = view(request)
- assert response.data == {'url': 'http://testserver/another/?version=v1'}
+ assert response.data == {"url": "http://testserver/another/?version=v1"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'url': 'http://testserver/another/'}
+ assert response.data == {"url": "http://testserver/another/"}
- @override_settings(ALLOWED_HOSTS=['*'])
+ @override_settings(ALLOWED_HOSTS=["*"])
def test_reverse_host_name_versioning(self):
scheme = versioning.HostNameVersioning
view = ReverseView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_HOST='v1.example.org')
+ request = factory.get("/endpoint/", HTTP_HOST="v1.example.org")
response = view(request)
- assert response.data == {'url': 'http://v1.example.org/another/'}
+ assert response.data == {"url": "http://v1.example.org/another/"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'url': 'http://testserver/another/'}
+ assert response.data == {"url": "http://testserver/another/"}
def test_reverse_url_path_versioning(self):
scheme = versioning.URLPathVersioning
view = ReverseView.as_view(versioning_class=scheme)
- request = factory.get('/v1/endpoint/')
- response = view(request, version='v1')
- assert response.data == {'url': 'http://testserver/v1/another/'}
+ request = factory.get("/v1/endpoint/")
+ response = view(request, version="v1")
+ assert response.data == {"url": "http://testserver/v1/another/"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'url': 'http://testserver/another/'}
+ assert response.data == {"url": "http://testserver/another/"}
def test_reverse_namespace_versioning(self):
class FakeResolverMatch:
- namespace = 'v1'
+ namespace = "v1"
scheme = versioning.NamespaceVersioning
view = ReverseView.as_view(versioning_class=scheme)
- request = factory.get('/v1/endpoint/')
+ request = factory.get("/v1/endpoint/")
request.resolver_match = FakeResolverMatch
- response = view(request, version='v1')
- assert response.data == {'url': 'http://testserver/v1/namespaced/'}
+ response = view(request, version="v1")
+ assert response.data == {"url": "http://testserver/v1/namespaced/"}
- request = factory.get('/endpoint/')
+ request = factory.get("/endpoint/")
response = view(request)
- assert response.data == {'url': 'http://testserver/another/'}
+ assert response.data == {"url": "http://testserver/another/"}
class TestInvalidVersion:
@@ -220,16 +220,16 @@ class TestInvalidVersion:
scheme = versioning.QueryParameterVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/?version=v3')
+ request = factory.get("/endpoint/?version=v3")
response = view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
- @override_settings(ALLOWED_HOSTS=['*'])
+ @override_settings(ALLOWED_HOSTS=["*"])
def test_invalid_host_name_versioning(self):
scheme = versioning.HostNameVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_HOST='v3.example.org')
+ request = factory.get("/endpoint/", HTTP_HOST="v3.example.org")
response = view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -237,7 +237,7 @@ class TestInvalidVersion:
scheme = versioning.AcceptHeaderVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="application/json; version=v3")
response = view(request)
assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE
@@ -245,20 +245,20 @@ class TestInvalidVersion:
scheme = versioning.URLPathVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/v3/endpoint/')
- response = view(request, version='v3')
+ request = factory.get("/v3/endpoint/")
+ response = view(request, version="v3")
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_invalid_namespace_versioning(self):
class FakeResolverMatch:
- namespace = 'v3'
+ namespace = "v3"
scheme = versioning.NamespaceVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/v3/endpoint/')
+ request = factory.get("/v3/endpoint/")
request.resolver_match = FakeResolverMatch
- response = view(request, version='v3')
+ response = view(request, version="v3")
assert response.status_code == status.HTTP_404_NOT_FOUND
@@ -267,7 +267,7 @@ class TestAllowedAndDefaultVersion:
scheme = versioning.AcceptHeaderVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="application/json")
response = view(request)
assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE
@@ -275,17 +275,16 @@ class TestAllowedAndDefaultVersion:
scheme = versioning.AcceptHeaderVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="application/json")
response = view(request)
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'version': 'v2'}
+ assert response.data == {"version": "v2"}
def test_with_default(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/',
- HTTP_ACCEPT='application/json; version=v2')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="application/json; version=v2")
response = view(request)
assert response.status_code == status.HTTP_200_OK
@@ -293,29 +292,27 @@ class TestAllowedAndDefaultVersion:
scheme = versioning.AcceptHeaderVersioning
view = AllowedWithNoneVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="application/json")
response = view(request)
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'version': None}
+ assert response.data == {"version": None}
def test_missing_with_default_and_none_allowed(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedWithNoneAndDefaultVersionsView.as_view(versioning_class=scheme)
- request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
+ request = factory.get("/endpoint/", HTTP_ACCEPT="application/json")
response = view(request)
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'version': 'v2'}
+ assert response.data == {"version": "v2"}
class TestHyperlinkedRelatedField(URLPatternsTestCase, APITestCase):
- included = [
- url(r'^namespaced/(?P\d+)/$', dummy_pk_view, name='namespaced'),
- ]
+ included = [url(r"^namespaced/(?P\d+)/$", dummy_pk_view, name="namespaced")]
urlpatterns = [
- url(r'^v1/', include((included, 'v1'), namespace='v1')),
- url(r'^v2/', include((included, 'v2'), namespace='v2'))
+ url(r"^v1/", include((included, "v1"), namespace="v1")),
+ url(r"^v2/", include((included, "v2"), namespace="v2")),
]
def setUp(self):
@@ -323,36 +320,38 @@ class TestHyperlinkedRelatedField(URLPatternsTestCase, APITestCase):
class MockQueryset(object):
def get(self, pk):
- return 'object %s' % pk
+ return "object %s" % pk
self.field = serializers.HyperlinkedRelatedField(
- view_name='namespaced',
- queryset=MockQueryset()
+ view_name="namespaced", queryset=MockQueryset()
)
- request = factory.get('/')
+ request = factory.get("/")
request.versioning_scheme = NamespaceVersioning()
- request.version = 'v1'
- self.field._context = {'request': request}
+ request.version = "v1"
+ self.field._context = {"request": request}
def test_bug_2489(self):
- assert self.field.to_internal_value('/v1/namespaced/3/') == 'object 3'
+ assert self.field.to_internal_value("/v1/namespaced/3/") == "object 3"
with pytest.raises(serializers.ValidationError):
- self.field.to_internal_value('/v2/namespaced/3/')
+ self.field.to_internal_value("/v2/namespaced/3/")
-class TestNamespaceVersioningHyperlinkedRelatedFieldScheme(URLPatternsTestCase, APITestCase):
- nested = [
- url(r'^namespaced/(?P\d+)/$', dummy_pk_view, name='nested'),
- ]
+class TestNamespaceVersioningHyperlinkedRelatedFieldScheme(
+ URLPatternsTestCase, APITestCase
+):
+ nested = [url(r"^namespaced/(?P\d+)/$", dummy_pk_view, name="nested")]
included = [
- url(r'^namespaced/(?P\d+)/$', dummy_pk_view, name='namespaced'),
- url(r'^nested/', include((nested, 'nested-namespace'), namespace='nested-namespace'))
+ url(r"^namespaced/(?P\d+)/$", dummy_pk_view, name="namespaced"),
+ url(
+ r"^nested/",
+ include((nested, "nested-namespace"), namespace="nested-namespace"),
+ ),
]
urlpatterns = [
- url(r'^v1/', include((included, 'restframeworkv1'), namespace='v1')),
- url(r'^v2/', include((included, 'restframeworkv2'), namespace='v2')),
- url(r'^non-api/(?P\d+)/$', dummy_pk_view, name='non-api-view')
+ url(r"^v1/", include((included, "restframeworkv1"), namespace="v1")),
+ url(r"^v2/", include((included, "restframeworkv2"), namespace="v2")),
+ url(r"^non-api/(?P\d+)/$", dummy_pk_view, name="non-api-view"),
]
def _create_field(self, view_name, version):
@@ -360,30 +359,41 @@ class TestNamespaceVersioningHyperlinkedRelatedFieldScheme(URLPatternsTestCase,
request.versioning_scheme = NamespaceVersioning()
request.version = version
- field = serializers.HyperlinkedRelatedField(
- view_name=view_name,
- read_only=True)
- field._context = {'request': request}
+ field = serializers.HyperlinkedRelatedField(view_name=view_name, read_only=True)
+ field._context = {"request": request}
return field
def test_api_url_is_properly_reversed_with_v1(self):
- field = self._create_field('namespaced', 'v1')
- assert field.to_representation(PKOnlyObject(3)) == 'http://testserver/v1/namespaced/3/'
+ field = self._create_field("namespaced", "v1")
+ assert (
+ field.to_representation(PKOnlyObject(3))
+ == "http://testserver/v1/namespaced/3/"
+ )
def test_api_url_is_properly_reversed_with_v2(self):
- field = self._create_field('namespaced', 'v2')
- assert field.to_representation(PKOnlyObject(5)) == 'http://testserver/v2/namespaced/5/'
+ field = self._create_field("namespaced", "v2")
+ assert (
+ field.to_representation(PKOnlyObject(5))
+ == "http://testserver/v2/namespaced/5/"
+ )
def test_api_url_is_properly_reversed_with_nested(self):
- field = self._create_field('nested', 'v1:nested-namespace')
- assert field.to_representation(PKOnlyObject(3)) == 'http://testserver/v1/nested/namespaced/3/'
+ field = self._create_field("nested", "v1:nested-namespace")
+ assert (
+ field.to_representation(PKOnlyObject(3))
+ == "http://testserver/v1/nested/namespaced/3/"
+ )
def test_non_api_url_is_properly_reversed_regardless_of_the_version(self):
"""
Regression test for #2711
"""
- field = self._create_field('non-api-view', 'v1')
- assert field.to_representation(PKOnlyObject(10)) == 'http://testserver/non-api/10/'
+ field = self._create_field("non-api-view", "v1")
+ assert (
+ field.to_representation(PKOnlyObject(10)) == "http://testserver/non-api/10/"
+ )
- field = self._create_field('non-api-view', 'v2')
- assert field.to_representation(PKOnlyObject(10)) == 'http://testserver/non-api/10/'
+ field = self._create_field("non-api-view", "v2")
+ assert (
+ field.to_representation(PKOnlyObject(10)) == "http://testserver/non-api/10/"
+ )
diff --git a/tests/test_views.py b/tests/test_views.py
index f0919e846..800318a21 100644
--- a/tests/test_views.py
+++ b/tests/test_views.py
@@ -12,32 +12,33 @@ from rest_framework.settings import APISettings, api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
+
factory = APIRequestFactory()
if sys.version_info[:2] >= (3, 4):
- JSON_ERROR = 'JSON parse error - Expecting value:'
+ JSON_ERROR = "JSON parse error - Expecting value:"
else:
- JSON_ERROR = 'JSON parse error - No JSON object could be decoded'
+ JSON_ERROR = "JSON parse error - No JSON object could be decoded"
class BasicView(APIView):
def get(self, request, *args, **kwargs):
- return Response({'method': 'GET'})
+ return Response({"method": "GET"})
def post(self, request, *args, **kwargs):
- return Response({'method': 'POST', 'data': request.data})
+ return Response({"method": "POST", "data": request.data})
-@api_view(['GET', 'POST', 'PUT', 'PATCH'])
+@api_view(["GET", "POST", "PUT", "PATCH"])
def basic_view(request):
- if request.method == 'GET':
- return {'method': 'GET'}
- elif request.method == 'POST':
- return {'method': 'POST', 'data': request.data}
- elif request.method == 'PUT':
- return {'method': 'PUT', 'data': request.data}
- elif request.method == 'PATCH':
- return {'method': 'PATCH', 'data': request.data}
+ if request.method == "GET":
+ return {"method": "GET"}
+ elif request.method == "POST":
+ return {"method": "POST", "data": request.data}
+ elif request.method == "PUT":
+ return {"method": "PUT", "data": request.data}
+ elif request.method == "PATCH":
+ return {"method": "PATCH", "data": request.data}
class ErrorView(APIView):
@@ -47,18 +48,18 @@ class ErrorView(APIView):
def custom_handler(exc, context):
if isinstance(exc, SyntaxError):
- return Response({'error': 'SyntaxError'}, status=400)
- return Response({'error': 'UnknownError'}, status=500)
+ return Response({"error": "SyntaxError"}, status=400)
+ return Response({"error": "UnknownError"}, status=500)
class OverridenSettingsView(APIView):
- settings = APISettings({'EXCEPTION_HANDLER': custom_handler})
+ settings = APISettings({"EXCEPTION_HANDLER": custom_handler})
def get(self, request, *args, **kwargs):
- raise SyntaxError('request is invalid syntax')
+ raise SyntaxError("request is invalid syntax")
-@api_view(['GET'])
+@api_view(["GET"])
def error_view(request):
raise Exception
@@ -70,7 +71,7 @@ def sanitise_json_error(error_dict):
"""
ret = copy.copy(error_dict)
chop = len(JSON_ERROR)
- ret['detail'] = ret['detail'][:chop]
+ ret["detail"] = ret["detail"][:chop]
return ret
@@ -79,11 +80,9 @@ class ClassBasedViewIntegrationTests(TestCase):
self.view = BasicView.as_view()
def test_400_parse_error(self):
- request = factory.post('/', 'f00bar', content_type='application/json')
+ request = factory.post("/", "f00bar", content_type="application/json")
response = self.view(request)
- expected = {
- 'detail': JSON_ERROR
- }
+ expected = {"detail": JSON_ERROR}
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert sanitise_json_error(response.data) == expected
@@ -93,11 +92,9 @@ class FunctionBasedViewIntegrationTests(TestCase):
self.view = basic_view
def test_400_parse_error(self):
- request = factory.post('/', 'f00bar', content_type='application/json')
+ request = factory.post("/", "f00bar", content_type="application/json")
response = self.view(request)
- expected = {
- 'detail': JSON_ERROR
- }
+ expected = {"detail": JSON_ERROR}
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert sanitise_json_error(response.data) == expected
@@ -107,7 +104,7 @@ class TestCustomExceptionHandler(TestCase):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
def exception_handler(exc, request):
- return Response('Error!', status=status.HTTP_400_BAD_REQUEST)
+ return Response("Error!", status=status.HTTP_400_BAD_REQUEST)
api_settings.EXCEPTION_HANDLER = exception_handler
@@ -117,18 +114,18 @@ class TestCustomExceptionHandler(TestCase):
def test_class_based_view_exception_handler(self):
view = ErrorView.as_view()
- request = factory.get('/', content_type='application/json')
+ request = factory.get("/", content_type="application/json")
response = view(request)
- expected = 'Error!'
+ expected = "Error!"
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.data == expected
def test_function_based_view_exception_handler(self):
view = error_view
- request = factory.get('/', content_type='application/json')
+ request = factory.get("/", content_type="application/json")
response = view(request)
- expected = 'Error!'
+ expected = "Error!"
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.data == expected
@@ -138,7 +135,7 @@ class TestCustomSettings(TestCase):
self.view = OverridenSettingsView.as_view()
def test_get_exception_handler(self):
- request = factory.get('/', content_type='application/json')
+ request = factory.get("/", content_type="application/json")
response = self.view(request)
assert response.status_code == 400
- assert response.data == {'error': 'SyntaxError'}
+ assert response.data == {"error": "SyntaxError"}
diff --git a/tests/test_viewsets.py b/tests/test_viewsets.py
index eac36f095..3a125de99 100644
--- a/tests/test_viewsets.py
+++ b/tests/test_viewsets.py
@@ -12,21 +12,21 @@ from rest_framework.routers import SimpleRouter
from rest_framework.test import APIRequestFactory
from rest_framework.viewsets import GenericViewSet
+
factory = APIRequestFactory()
class BasicViewSet(GenericViewSet):
def list(self, request, *args, **kwargs):
- return Response({'ACTION': 'LIST'})
+ return Response({"ACTION": "LIST"})
class InstanceViewSet(GenericViewSet):
-
def dispatch(self, request, *args, **kwargs):
return self.dummy(request, *args, **kwargs)
def dummy(self, request, *args, **kwargs):
- return Response({'view': self})
+ return Response({"view": self})
class Action(models.Model):
@@ -46,7 +46,7 @@ class ActionViewSet(GenericViewSet):
def list_action(self, request, *args, **kwargs):
raise NotImplementedError
- @action(detail=False, url_name='list-custom')
+ @action(detail=False, url_name="list-custom")
def custom_list_action(self, request, *args, **kwargs):
raise NotImplementedError
@@ -54,17 +54,16 @@ class ActionViewSet(GenericViewSet):
def detail_action(self, request, *args, **kwargs):
raise NotImplementedError
- @action(detail=True, url_name='detail-custom')
+ @action(detail=True, url_name="detail-custom")
def custom_detail_action(self, request, *args, **kwargs):
raise NotImplementedError
- @action(detail=True, url_path=r'unresolvable/(?P\w+)', url_name='unresolvable')
+ @action(detail=True, url_path=r"unresolvable/(?P\w+)", url_name="unresolvable")
def unresolvable_detail_action(self, request, *args, **kwargs):
raise NotImplementedError
class ActionNamesViewSet(GenericViewSet):
-
def retrieve(self, request, *args, **kwargs):
return Response()
@@ -72,42 +71,36 @@ class ActionNamesViewSet(GenericViewSet):
def unnamed_action(self, request, *args, **kwargs):
raise NotImplementedError
- @action(detail=True, name='Custom Name')
+ @action(detail=True, name="Custom Name")
def named_action(self, request, *args, **kwargs):
raise NotImplementedError
- @action(detail=True, suffix='Custom Suffix')
+ @action(detail=True, suffix="Custom Suffix")
def suffixed_action(self, request, *args, **kwargs):
raise NotImplementedError
router = SimpleRouter()
-router.register(r'actions', ActionViewSet)
-router.register(r'actions-alt', ActionViewSet, basename='actions-alt')
-router.register(r'names', ActionNamesViewSet, basename='names')
+router.register(r"actions", ActionViewSet)
+router.register(r"actions-alt", ActionViewSet, basename="actions-alt")
+router.register(r"names", ActionNamesViewSet, basename="names")
-urlpatterns = [
- url(r'^api/', include(router.urls)),
-]
+urlpatterns = [url(r"^api/", include(router.urls))]
class InitializeViewSetsTestCase(TestCase):
def test_initialize_view_set_with_actions(self):
- request = factory.get('/', '', content_type='application/json')
- my_view = BasicViewSet.as_view(actions={
- 'get': 'list',
- })
+ request = factory.get("/", "", content_type="application/json")
+ my_view = BasicViewSet.as_view(actions={"get": "list"})
response = my_view(request)
assert response.status_code == status.HTTP_200_OK
- assert response.data == {'ACTION': 'LIST'}
+ assert response.data == {"ACTION": "LIST"}
def testhead_request_against_viewset(self):
- request = factory.head('/', '', content_type='application/json')
- my_view = BasicViewSet.as_view(actions={
- 'get': 'list',
- })
+ request = factory.head("/", "", content_type="application/json")
+ my_view = BasicViewSet.as_view(actions={"get": "list"})
response = my_view(request)
assert response.status_code == status.HTTP_200_OK
@@ -119,17 +112,17 @@ class InitializeViewSetsTestCase(TestCase):
assert str(excinfo.value) == (
"The `actions` argument must be provided "
"when calling `.as_view()` on a ViewSet. "
- "For example `.as_view({'get': 'list'})`")
+ "For example `.as_view({'get': 'list'})`"
+ )
def test_initialize_view_set_with_both_name_and_suffix(self):
with pytest.raises(TypeError) as excinfo:
- BasicViewSet.as_view(name='', suffix='', actions={
- 'get': 'list',
- })
+ BasicViewSet.as_view(name="", suffix="", actions={"get": "list"})
assert str(excinfo.value) == (
"BasicViewSet() received both `name` and `suffix`, "
- "which are mutually exclusive arguments.")
+ "which are mutually exclusive arguments."
+ )
def test_args_kwargs_request_action_map_on_self(self):
"""
@@ -137,54 +130,62 @@ class InitializeViewSetsTestCase(TestCase):
once `as_view` has been called.
"""
bare_view = InstanceViewSet()
- view = InstanceViewSet.as_view(actions={
- 'get': 'dummy',
- })(factory.get('/')).data['view']
+ view = InstanceViewSet.as_view(actions={"get": "dummy"})(factory.get("/")).data[
+ "view"
+ ]
- for attribute in ('args', 'kwargs', 'request', 'action_map'):
+ for attribute in ("args", "kwargs", "request", "action_map"):
self.assertNotIn(attribute, dir(bare_view))
self.assertIn(attribute, dir(view))
class GetExtraActionsTests(TestCase):
-
def test_extra_actions(self):
view = ActionViewSet()
actual = [action.__name__ for action in view.get_extra_actions()]
expected = [
- 'custom_detail_action',
- 'custom_list_action',
- 'detail_action',
- 'list_action',
- 'unresolvable_detail_action',
+ "custom_detail_action",
+ "custom_list_action",
+ "detail_action",
+ "list_action",
+ "unresolvable_detail_action",
]
self.assertEqual(actual, expected)
-@override_settings(ROOT_URLCONF='tests.test_viewsets')
+@override_settings(ROOT_URLCONF="tests.test_viewsets")
class GetExtraActionUrlMapTests(TestCase):
-
def test_list_view(self):
- response = self.client.get('/api/actions/')
- view = response.renderer_context['view']
+ response = self.client.get("/api/actions/")
+ view = response.renderer_context["view"]
- expected = OrderedDict([
- ('Custom list action', 'http://testserver/api/actions/custom_list_action/'),
- ('List action', 'http://testserver/api/actions/list_action/'),
- ])
+ expected = OrderedDict(
+ [
+ (
+ "Custom list action",
+ "http://testserver/api/actions/custom_list_action/",
+ ),
+ ("List action", "http://testserver/api/actions/list_action/"),
+ ]
+ )
self.assertEqual(view.get_extra_action_url_map(), expected)
def test_detail_view(self):
- response = self.client.get('/api/actions/1/')
- view = response.renderer_context['view']
+ response = self.client.get("/api/actions/1/")
+ view = response.renderer_context["view"]
- expected = OrderedDict([
- ('Custom detail action', 'http://testserver/api/actions/1/custom_detail_action/'),
- ('Detail action', 'http://testserver/api/actions/1/detail_action/'),
- # "Unresolvable detail action" excluded, since it's not resolvable
- ])
+ expected = OrderedDict(
+ [
+ (
+ "Custom detail action",
+ "http://testserver/api/actions/1/custom_detail_action/",
+ ),
+ ("Detail action", "http://testserver/api/actions/1/detail_action/"),
+ # "Unresolvable detail action" excluded, since it's not resolvable
+ ]
+ )
self.assertEqual(view.get_extra_action_url_map(), expected)
@@ -193,53 +194,72 @@ class GetExtraActionUrlMapTests(TestCase):
def test_action_names(self):
# Action 'name' and 'suffix' kwargs should be respected
- response = self.client.get('/api/names/1/')
- view = response.renderer_context['view']
+ response = self.client.get("/api/names/1/")
+ view = response.renderer_context["view"]
- expected = OrderedDict([
- ('Custom Name', 'http://testserver/api/names/1/named_action/'),
- ('Action Names Custom Suffix', 'http://testserver/api/names/1/suffixed_action/'),
- ('Unnamed action', 'http://testserver/api/names/1/unnamed_action/'),
- ])
+ expected = OrderedDict(
+ [
+ ("Custom Name", "http://testserver/api/names/1/named_action/"),
+ (
+ "Action Names Custom Suffix",
+ "http://testserver/api/names/1/suffixed_action/",
+ ),
+ ("Unnamed action", "http://testserver/api/names/1/unnamed_action/"),
+ ]
+ )
self.assertEqual(view.get_extra_action_url_map(), expected)
-@override_settings(ROOT_URLCONF='tests.test_viewsets')
+@override_settings(ROOT_URLCONF="tests.test_viewsets")
class ReverseActionTests(TestCase):
def test_default_basename(self):
view = ActionViewSet()
view.basename = router.get_default_basename(ActionViewSet)
view.request = None
- assert view.reverse_action('list') == '/api/actions/'
- assert view.reverse_action('list-action') == '/api/actions/list_action/'
- assert view.reverse_action('list-custom') == '/api/actions/custom_list_action/'
+ assert view.reverse_action("list") == "/api/actions/"
+ assert view.reverse_action("list-action") == "/api/actions/list_action/"
+ assert view.reverse_action("list-custom") == "/api/actions/custom_list_action/"
- assert view.reverse_action('detail', args=['1']) == '/api/actions/1/'
- assert view.reverse_action('detail-action', args=['1']) == '/api/actions/1/detail_action/'
- assert view.reverse_action('detail-custom', args=['1']) == '/api/actions/1/custom_detail_action/'
+ assert view.reverse_action("detail", args=["1"]) == "/api/actions/1/"
+ assert (
+ view.reverse_action("detail-action", args=["1"])
+ == "/api/actions/1/detail_action/"
+ )
+ assert (
+ view.reverse_action("detail-custom", args=["1"])
+ == "/api/actions/1/custom_detail_action/"
+ )
def test_custom_basename(self):
view = ActionViewSet()
- view.basename = 'actions-alt'
+ view.basename = "actions-alt"
view.request = None
- assert view.reverse_action('list') == '/api/actions-alt/'
- assert view.reverse_action('list-action') == '/api/actions-alt/list_action/'
- assert view.reverse_action('list-custom') == '/api/actions-alt/custom_list_action/'
+ assert view.reverse_action("list") == "/api/actions-alt/"
+ assert view.reverse_action("list-action") == "/api/actions-alt/list_action/"
+ assert (
+ view.reverse_action("list-custom") == "/api/actions-alt/custom_list_action/"
+ )
- assert view.reverse_action('detail', args=['1']) == '/api/actions-alt/1/'
- assert view.reverse_action('detail-action', args=['1']) == '/api/actions-alt/1/detail_action/'
- assert view.reverse_action('detail-custom', args=['1']) == '/api/actions-alt/1/custom_detail_action/'
+ assert view.reverse_action("detail", args=["1"]) == "/api/actions-alt/1/"
+ assert (
+ view.reverse_action("detail-action", args=["1"])
+ == "/api/actions-alt/1/detail_action/"
+ )
+ assert (
+ view.reverse_action("detail-custom", args=["1"])
+ == "/api/actions-alt/1/custom_detail_action/"
+ )
def test_request_passing(self):
view = ActionViewSet()
view.basename = router.get_default_basename(ActionViewSet)
- view.request = factory.get('/')
+ view.request = factory.get("/")
# Passing the view's request object should result in an absolute URL.
- assert view.reverse_action('list') == 'http://testserver/api/actions/'
+ assert view.reverse_action("list") == "http://testserver/api/actions/"
# Users should be able to explicitly not pass the view's request.
- assert view.reverse_action('list', request=None) == '/api/actions/'
+ assert view.reverse_action("list", request=None) == "/api/actions/"
diff --git a/tests/test_write_only_fields.py b/tests/test_write_only_fields.py
index fd712f837..e631a33b5 100644
--- a/tests/test_write_only_fields.py
+++ b/tests/test_write_only_fields.py
@@ -12,18 +12,12 @@ class WriteOnlyFieldTests(TestCase):
self.Serializer = ExampleSerializer
def test_write_only_fields_are_present_on_input(self):
- data = {
- 'email': 'foo@example.com',
- 'password': '123'
- }
+ data = {"email": "foo@example.com", "password": "123"}
serializer = self.Serializer(data=data)
assert serializer.is_valid()
assert serializer.validated_data == data
def test_write_only_fields_are_not_present_on_output(self):
- instance = {
- 'email': 'foo@example.com',
- 'password': '123'
- }
+ instance = {"email": "foo@example.com", "password": "123"}
serializer = self.Serializer(instance)
- assert serializer.data == {'email': 'foo@example.com'}
+ assert serializer.data == {"email": "foo@example.com"}
diff --git a/tests/urls.py b/tests/urls.py
index 76ada5e3d..29ecf35c9 100644
--- a/tests/urls.py
+++ b/tests/urls.py
@@ -8,9 +8,8 @@ from django.conf.urls import url
from rest_framework.compat import coreapi
from rest_framework.documentation import include_docs_urls
+
if coreapi:
- urlpatterns = [
- url(r'^docs/', include_docs_urls(title='Test Suite API')),
- ]
+ urlpatterns = [url(r"^docs/", include_docs_urls(title="Test Suite API"))]
else:
urlpatterns = []
diff --git a/tests/utils.py b/tests/utils.py
index 509e6a102..2c7d2eb6f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -9,11 +9,10 @@ class MockObject(object):
setattr(self, key, val)
def __str__(self):
- kwargs_str = ', '.join([
- '%s=%s' % (key, value)
- for key, value in sorted(self._kwargs.items())
- ])
- return '' % kwargs_str
+ kwargs_str = ", ".join(
+ ["%s=%s" % (key, value) for key, value in sorted(self._kwargs.items())]
+ )
+ return "" % kwargs_str
class MockQueryset(object):
@@ -25,10 +24,9 @@ class MockQueryset(object):
def get(self, **lookup):
for item in self.items:
- if all([
- getattr(item, key, None) == value
- for key, value in lookup.items()
- ]):
+ if all(
+ [getattr(item, key, None) == value for key, value in lookup.items()]
+ ):
return item
raise ObjectDoesNotExist()
@@ -39,6 +37,7 @@ class BadType(object):
will raise a `TypeError`, as occurs in Django when making
queryset lookups with an incorrect type for the lookup value.
"""
+
def __eq__(self):
raise TypeError()
@@ -46,10 +45,10 @@ class BadType(object):
def mock_reverse(view_name, args=None, kwargs=None, request=None, format=None):
args = args or []
kwargs = kwargs or {}
- value = (args + list(kwargs.values()) + ['-'])[0]
- prefix = 'http://example.org' if request else ''
- suffix = ('.' + format) if (format is not None) else ''
- return '%s/%s/%s%s/' % (prefix, view_name, value, suffix)
+ value = (args + list(kwargs.values()) + ["-"])[0]
+ prefix = "http://example.org" if request else ""
+ suffix = ("." + format) if (format is not None) else ""
+ return "%s/%s/%s%s/" % (prefix, view_name, value, suffix)
def fail_reverse(view_name, args=None, kwargs=None, request=None, format=None):