black/flake8/isort fix

This commit is contained in:
Rizwan Mansuri 2019-04-13 12:16:48 +01:00
parent f2eacd3660
commit f418b4e9ca
134 changed files with 10876 additions and 8988 deletions

View File

@ -4,7 +4,7 @@ flake8-tidy-imports==1.1.0
pycodestyle==2.3.1 pycodestyle==2.3.1
# Sort and lint imports # Sort and lint imports
isort==4.3.3 isort==4.3.17
# black # black
black==19.3b0 black==19.3b0

View File

@ -7,22 +7,22 @@ ______ _____ _____ _____ __
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_| \_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
""" """
__title__ = 'Django REST framework' __title__ = "Django REST framework"
__version__ = '3.9.2' __version__ = "3.9.2"
__author__ = 'Tom Christie' __author__ = "Tom Christie"
__license__ = 'BSD 2-Clause' __license__ = "BSD 2-Clause"
__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd' __copyright__ = "Copyright 2011-2019 Encode OSS Ltd"
# Version synonym # Version synonym
VERSION = __version__ VERSION = __version__
# Header encoding (see RFC5987) # Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1' HTTP_HEADER_ENCODING = "iso-8859-1"
# Default datetime input and output formats # Default datetime input and output formats
ISO_8601 = 'iso-8601' ISO_8601 = "iso-8601"
default_app_config = 'rest_framework.apps.RestFrameworkConfig' default_app_config = "rest_framework.apps.RestFrameworkConfig"
class RemovedInDRF310Warning(DeprecationWarning): class RemovedInDRF310Warning(DeprecationWarning):

View File

@ -2,7 +2,7 @@ from django.apps import AppConfig
class RestFrameworkConfig(AppConfig): class RestFrameworkConfig(AppConfig):
name = 'rest_framework' name = "rest_framework"
verbose_name = "Django REST framework" verbose_name = "Django REST framework"
def ready(self): def ready(self):

View File

@ -20,7 +20,7 @@ def get_authorization_header(request):
Hide some test client ickyness where the header can be unicode. Hide some test client ickyness where the header can be unicode.
""" """
auth = request.META.get('HTTP_AUTHORIZATION', b'') auth = request.META.get("HTTP_AUTHORIZATION", b"")
if isinstance(auth, text_type): if isinstance(auth, text_type):
# Work around django test client oddness # Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING) auth = auth.encode(HTTP_HEADER_ENCODING)
@ -57,7 +57,8 @@ class BasicAuthentication(BaseAuthentication):
""" """
HTTP Basic authentication against username/password. HTTP Basic authentication against username/password.
""" """
www_authenticate_realm = 'api'
www_authenticate_realm = "api"
def authenticate(self, request): def authenticate(self, request):
""" """
@ -66,20 +67,24 @@ class BasicAuthentication(BaseAuthentication):
""" """
auth = get_authorization_header(request).split() auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != b'basic': if not auth or auth[0].lower() != b"basic":
return None return None
if len(auth) == 1: if len(auth) == 1:
msg = _('Invalid basic header. No credentials provided.') msg = _("Invalid basic header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2: elif len(auth) > 2:
msg = _('Invalid basic header. Credentials string should not contain spaces.') msg = _(
"Invalid basic header. Credentials string should not contain spaces."
)
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
try: try:
auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':') auth_parts = (
base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(":")
)
except (TypeError, UnicodeDecodeError, binascii.Error): except (TypeError, UnicodeDecodeError, binascii.Error):
msg = _('Invalid basic header. Credentials not correctly base64 encoded.') msg = _("Invalid basic header. Credentials not correctly base64 encoded.")
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
userid, password = auth_parts[0], auth_parts[2] userid, password = auth_parts[0], auth_parts[2]
@ -90,17 +95,14 @@ class BasicAuthentication(BaseAuthentication):
Authenticate the userid and password against username and password Authenticate the userid and password against username and password
with optional request for context. with optional request for context.
""" """
credentials = { credentials = {get_user_model().USERNAME_FIELD: userid, "password": password}
get_user_model().USERNAME_FIELD: userid,
'password': password
}
user = authenticate(request=request, **credentials) user = authenticate(request=request, **credentials)
if user is None: if user is None:
raise exceptions.AuthenticationFailed(_('Invalid username/password.')) raise exceptions.AuthenticationFailed(_("Invalid username/password."))
if not user.is_active: if not user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.')) raise exceptions.AuthenticationFailed(_("User inactive or deleted."))
return (user, None) return (user, None)
@ -120,7 +122,7 @@ class SessionAuthentication(BaseAuthentication):
""" """
# Get the session-based user from the underlying HttpRequest object # Get the session-based user from the underlying HttpRequest object
user = getattr(request._request, 'user', None) user = getattr(request._request, "user", None)
# Unauthenticated, CSRF validation not required # Unauthenticated, CSRF validation not required
if not user or not user.is_active: if not user or not user.is_active:
@ -141,7 +143,7 @@ class SessionAuthentication(BaseAuthentication):
reason = check.process_view(request, None, (), {}) reason = check.process_view(request, None, (), {})
if reason: if reason:
# CSRF failed, bail with explicit error message # CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) raise exceptions.PermissionDenied("CSRF Failed: %s" % reason)
class TokenAuthentication(BaseAuthentication): class TokenAuthentication(BaseAuthentication):
@ -154,13 +156,14 @@ class TokenAuthentication(BaseAuthentication):
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
""" """
keyword = 'Token' keyword = "Token"
model = None model = None
def get_model(self): def get_model(self):
if self.model is not None: if self.model is not None:
return self.model return self.model
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
return Token return Token
""" """
@ -177,16 +180,18 @@ class TokenAuthentication(BaseAuthentication):
return None return None
if len(auth) == 1: if len(auth) == 1:
msg = _('Invalid token header. No credentials provided.') msg = _("Invalid token header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2: elif len(auth) > 2:
msg = _('Invalid token header. Token string should not contain spaces.') msg = _("Invalid token header. Token string should not contain spaces.")
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
try: try:
token = auth[1].decode() token = auth[1].decode()
except UnicodeError: except UnicodeError:
msg = _('Invalid token header. Token string should not contain invalid characters.') msg = _(
"Invalid token header. Token string should not contain invalid characters."
)
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(token) return self.authenticate_credentials(token)
@ -194,12 +199,12 @@ class TokenAuthentication(BaseAuthentication):
def authenticate_credentials(self, key): def authenticate_credentials(self, key):
model = self.get_model() model = self.get_model()
try: try:
token = model.objects.select_related('user').get(key=key) token = model.objects.select_related("user").get(key=key)
except model.DoesNotExist: except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.')) raise exceptions.AuthenticationFailed(_("Invalid token."))
if not token.user.is_active: if not token.user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.')) raise exceptions.AuthenticationFailed(_("User inactive or deleted."))
return (token.user, token) return (token.user, token)

View File

@ -1 +1 @@
default_app_config = 'rest_framework.authtoken.apps.AuthTokenConfig' default_app_config = "rest_framework.authtoken.apps.AuthTokenConfig"

View File

@ -4,9 +4,9 @@ from rest_framework.authtoken.models import Token
class TokenAdmin(admin.ModelAdmin): class TokenAdmin(admin.ModelAdmin):
list_display = ('key', 'user', 'created') list_display = ("key", "user", "created")
fields = ('user',) fields = ("user",)
ordering = ('-created',) ordering = ("-created",)
admin.site.register(Token, TokenAdmin) admin.site.register(Token, TokenAdmin)

View File

@ -3,5 +3,5 @@ from django.utils.translation import ugettext_lazy as _
class AuthTokenConfig(AppConfig): class AuthTokenConfig(AppConfig):
name = 'rest_framework.authtoken' name = "rest_framework.authtoken"
verbose_name = _("Auth Token") verbose_name = _("Auth Token")

View File

@ -3,11 +3,12 @@ from django.core.management.base import BaseCommand, CommandError
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
UserModel = get_user_model() UserModel = get_user_model()
class Command(BaseCommand): class Command(BaseCommand):
help = 'Create DRF Token for a given user' help = "Create DRF Token for a given user"
def create_user_token(self, username, reset_token): def create_user_token(self, username, reset_token):
user = UserModel._default_manager.get_by_natural_key(username) user = UserModel._default_manager.get_by_natural_key(username)
@ -19,27 +20,27 @@ class Command(BaseCommand):
return token[0] return token[0]
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('username', type=str) parser.add_argument("username", type=str)
parser.add_argument( parser.add_argument(
'-r', "-r",
'--reset', "--reset",
action='store_true', action="store_true",
dest='reset_token', dest="reset_token",
default=False, default=False,
help='Reset existing User token and create a new one', help="Reset existing User token and create a new one",
) )
def handle(self, *args, **options): def handle(self, *args, **options):
username = options['username'] username = options["username"]
reset_token = options['reset_token'] reset_token = options["reset_token"]
try: try:
token = self.create_user_token(username, reset_token) token = self.create_user_token(username, reset_token)
except UserModel.DoesNotExist: except UserModel.DoesNotExist:
raise CommandError( raise CommandError(
'Cannot create the Token: user {0} does not exist'.format( "Cannot create the Token: user {0} does not exist".format(username)
username)
) )
self.stdout.write( self.stdout.write(
'Generated token {0} for user {1}'.format(token.key, username)) "Generated token {0} for user {1}".format(token.key, username)
)

View File

@ -7,20 +7,27 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)]
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='Token', name="Token",
fields=[ fields=[
('key', models.CharField(primary_key=True, serialize=False, max_length=40)), (
('created', models.DateTimeField(auto_now_add=True)), "key",
('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token', on_delete=models.CASCADE)), models.CharField(primary_key=True, serialize=False, max_length=40),
),
("created", models.DateTimeField(auto_now_add=True)),
(
"user",
models.OneToOneField(
to=settings.AUTH_USER_MODEL,
related_name="auth_token",
on_delete=models.CASCADE,
),
),
], ],
options={ options={},
},
bases=(models.Model,), bases=(models.Model,),
), )
] ]

View File

@ -7,28 +7,33 @@ from django.db import migrations, models
class Migration(migrations.Migration): class Migration(migrations.Migration):
dependencies = [ dependencies = [("authtoken", "0001_initial")]
('authtoken', '0001_initial'),
]
operations = [ operations = [
migrations.AlterModelOptions( migrations.AlterModelOptions(
name='token', name="token",
options={'verbose_name_plural': 'Tokens', 'verbose_name': 'Token'}, options={"verbose_name_plural": "Tokens", "verbose_name": "Token"},
), ),
migrations.AlterField( migrations.AlterField(
model_name='token', model_name="token",
name='created', name="created",
field=models.DateTimeField(verbose_name='Created', auto_now_add=True), field=models.DateTimeField(verbose_name="Created", auto_now_add=True),
), ),
migrations.AlterField( migrations.AlterField(
model_name='token', model_name="token",
name='key', name="key",
field=models.CharField(verbose_name='Key', max_length=40, primary_key=True, serialize=False), field=models.CharField(
verbose_name="Key", max_length=40, primary_key=True, serialize=False
),
), ),
migrations.AlterField( migrations.AlterField(
model_name='token', model_name="token",
name='user', name="user",
field=models.OneToOneField(to=settings.AUTH_USER_MODEL, verbose_name='User', related_name='auth_token', on_delete=models.CASCADE), field=models.OneToOneField(
to=settings.AUTH_USER_MODEL,
verbose_name="User",
related_name="auth_token",
on_delete=models.CASCADE,
),
), ),
] ]

View File

@ -12,10 +12,13 @@ class Token(models.Model):
""" """
The default authorization token model. The default authorization token model.
""" """
key = models.CharField(_("Key"), max_length=40, primary_key=True) key = models.CharField(_("Key"), max_length=40, primary_key=True)
user = models.OneToOneField( user = models.OneToOneField(
settings.AUTH_USER_MODEL, related_name='auth_token', settings.AUTH_USER_MODEL,
on_delete=models.CASCADE, verbose_name=_("User") related_name="auth_token",
on_delete=models.CASCADE,
verbose_name=_("User"),
) )
created = models.DateTimeField(_("Created"), auto_now_add=True) created = models.DateTimeField(_("Created"), auto_now_add=True)
@ -25,7 +28,7 @@ class Token(models.Model):
# #
# Also see corresponding ticket: # Also see corresponding ticket:
# https://github.com/encode/django-rest-framework/issues/705 # https://github.com/encode/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS abstract = "rest_framework.authtoken" not in settings.INSTALLED_APPS
verbose_name = _("Token") verbose_name = _("Token")
verbose_name_plural = _("Tokens") verbose_name_plural = _("Tokens")

View File

@ -7,28 +7,29 @@ from rest_framework import serializers
class AuthTokenSerializer(serializers.Serializer): class AuthTokenSerializer(serializers.Serializer):
username = serializers.CharField(label=_("Username")) username = serializers.CharField(label=_("Username"))
password = serializers.CharField( password = serializers.CharField(
label=_("Password"), label=_("Password"), style={"input_type": "password"}, trim_whitespace=False
style={'input_type': 'password'},
trim_whitespace=False
) )
def validate(self, attrs): def validate(self, attrs):
username = attrs.get('username') username = attrs.get("username")
password = attrs.get('password') password = attrs.get("password")
if username and password: if username and password:
user = authenticate(request=self.context.get('request'), user = authenticate(
username=username, password=password) request=self.context.get("request"),
username=username,
password=password,
)
# The authenticate call simply returns None for is_active=False # The authenticate call simply returns None for is_active=False
# users. (Assuming the default ModelBackend authentication # users. (Assuming the default ModelBackend authentication
# backend.) # backend.)
if not user: if not user:
msg = _('Unable to log in with provided credentials.') msg = _("Unable to log in with provided credentials.")
raise serializers.ValidationError(msg, code='authorization') raise serializers.ValidationError(msg, code="authorization")
else: else:
msg = _('Must include "username" and "password".') msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg, code='authorization') raise serializers.ValidationError(msg, code="authorization")
attrs['user'] = user attrs["user"] = user
return attrs return attrs

View File

@ -10,7 +10,7 @@ from rest_framework.views import APIView
class ObtainAuthToken(APIView): class ObtainAuthToken(APIView):
throttle_classes = () throttle_classes = ()
permission_classes = () permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser)
renderer_classes = (renderers.JSONRenderer,) renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer serializer_class = AuthTokenSerializer
if coreapi is not None and coreschema is not None: if coreapi is not None and coreschema is not None:
@ -19,7 +19,7 @@ class ObtainAuthToken(APIView):
coreapi.Field( coreapi.Field(
name="username", name="username",
required=True, required=True,
location='form', location="form",
schema=coreschema.String( schema=coreschema.String(
title="Username", title="Username",
description="Valid username for authentication", description="Valid username for authentication",
@ -28,7 +28,7 @@ class ObtainAuthToken(APIView):
coreapi.Field( coreapi.Field(
name="password", name="password",
required=True, required=True,
location='form', location="form",
schema=coreschema.String( schema=coreschema.String(
title="Password", title="Password",
description="Valid password for authentication", description="Valid password for authentication",
@ -39,12 +39,13 @@ class ObtainAuthToken(APIView):
) )
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
serializer = self.serializer_class(data=request.data, serializer = self.serializer_class(
context={'request': request}) data=request.data, context={"request": request}
)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user'] user = serializer.validated_data["user"]
token, created = Token.objects.get_or_create(user=user) token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key}) return Response({"token": token.key})
obtain_auth_token = ObtainAuthToken.as_view() obtain_auth_token = ObtainAuthToken.as_view()

View File

@ -6,16 +6,17 @@ def pagination_system_check(app_configs, **kwargs):
errors = [] errors = []
# Use of default page size setting requires a default Paginator class # Use of default page size setting requires a default Paginator class
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS: if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS:
errors.append( errors.append(
Warning( Warning(
"You have specified a default PAGE_SIZE pagination rest_framework setting," "You have specified a default PAGE_SIZE pagination rest_framework setting,"
"without specifying also a DEFAULT_PAGINATION_CLASS.", "without specifying also a DEFAULT_PAGINATION_CLASS.",
hint="The default for DEFAULT_PAGINATION_CLASS is None. " hint="The default for DEFAULT_PAGINATION_CLASS is None. "
"In previous versions this was PageNumberPagination. " "In previous versions this was PageNumberPagination. "
"If you wish to define PAGE_SIZE globally whilst defining " "If you wish to define PAGE_SIZE globally whilst defining "
"pagination_class on a per-view basis you may silence this check.", "pagination_class on a per-view basis you may silence this check.",
id="rest_framework.W001" id="rest_framework.W001",
) )
) )
return errors return errors

View File

@ -12,18 +12,16 @@ from django.core import validators
from django.utils import six from django.utils import six
from django.views.generic import View from django.views.generic import View
try:
# Python 3
from collections.abc import Mapping, MutableMapping # noqa
except ImportError:
# Python 2.7
from collections import Mapping, MutableMapping # noqa
try: try:
from django.urls import ( # noqa # Python 3
URLPattern, from collections.abc import Mapping, MutableMapping # noqa
URLResolver, except ImportError:
) # Python 2.7
from collections import Mapping, MutableMapping # noqa
try:
from django.urls import URLPattern, URLResolver # noqa
except ImportError: except ImportError:
# Will be removed in Django 2.0 # Will be removed in Django 2.0
from django.urls import ( # noqa from django.urls import ( # noqa
@ -47,7 +45,7 @@ def get_original_route(urlpattern):
Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This
is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path(). is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path().
""" """
if hasattr(urlpattern, 'pattern'): if hasattr(urlpattern, "pattern"):
# Django 2.0 # Django 2.0
return str(urlpattern.pattern) return str(urlpattern.pattern)
else: else:
@ -60,7 +58,7 @@ def get_regex_pattern(urlpattern):
Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression, Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression,
unlike get_original_route above. unlike get_original_route above.
""" """
if hasattr(urlpattern, 'pattern'): if hasattr(urlpattern, "pattern"):
# Django 2.0 # Django 2.0
return urlpattern.pattern.regex.pattern return urlpattern.pattern.regex.pattern
else: else:
@ -69,9 +67,10 @@ def get_regex_pattern(urlpattern):
def is_route_pattern(urlpattern): def is_route_pattern(urlpattern):
if hasattr(urlpattern, 'pattern'): if hasattr(urlpattern, "pattern"):
# Django 2.0 # Django 2.0
from django.urls.resolvers import RoutePattern from django.urls.resolvers import RoutePattern
return isinstance(urlpattern.pattern, RoutePattern) return isinstance(urlpattern.pattern, RoutePattern)
else: else:
# Django < 2.0 # Django < 2.0
@ -82,6 +81,7 @@ def make_url_resolver(regex, urlpatterns):
try: try:
# Django 2.0 # Django 2.0
from django.urls.resolvers import RegexPattern from django.urls.resolvers import RegexPattern
return URLResolver(RegexPattern(regex), urlpatterns) return URLResolver(RegexPattern(regex), urlpatterns)
except ImportError: except ImportError:
@ -93,7 +93,7 @@ def unicode_repr(instance):
# Get the repr of an instance, but ensure it is a unicode string # Get the repr of an instance, but ensure it is a unicode string
# on both python 3 (already the case) and 2 (not the case). # on both python 3 (already the case) and 2 (not the case).
if six.PY2: if six.PY2:
return repr(instance).decode('utf-8') return repr(instance).decode("utf-8")
return repr(instance) return repr(instance)
@ -102,21 +102,21 @@ def unicode_to_repr(value):
# the Python version. We wrap all our `__repr__` implementations with # the Python version. We wrap all our `__repr__` implementations with
# this and then use unicode throughout internally. # this and then use unicode throughout internally.
if six.PY2: if six.PY2:
return value.encode('utf-8') return value.encode("utf-8")
return value return value
def unicode_http_header(value): def unicode_http_header(value):
# Coerce HTTP header value to unicode. # Coerce HTTP header value to unicode.
if isinstance(value, bytes): if isinstance(value, bytes):
return value.decode('iso-8859-1') return value.decode("iso-8859-1")
return value return value
def distinct(queryset, base): def distinct(queryset, base):
if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle": if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle":
# distinct analogue for Oracle users # distinct analogue for Oracle users
return base.filter(pk__in=set(queryset.values_list('pk', flat=True))) return base.filter(pk__in=set(queryset.values_list("pk", flat=True)))
return queryset.distinct() return queryset.distinct()
@ -172,27 +172,27 @@ def is_guardian_installed():
# Guardian 1.5.0, for Django 2.2 is NOT compatible with Python 2.7. # Guardian 1.5.0, for Django 2.2 is NOT compatible with Python 2.7.
# Remove when dropping PY2. # Remove when dropping PY2.
return False return False
return 'guardian' in settings.INSTALLED_APPS return "guardian" in settings.INSTALLED_APPS
# PATCH method is not implemented by Django # PATCH method is not implemented by Django
if 'patch' not in View.http_method_names: if "patch" not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch'] View.http_method_names = View.http_method_names + ["patch"]
# Markdown is optional # Markdown is optional
try: try:
import markdown import markdown
if markdown.version <= '2.2': if markdown.version <= "2.2":
HEADERID_EXT_PATH = 'headerid' HEADERID_EXT_PATH = "headerid"
LEVEL_PARAM = 'level' LEVEL_PARAM = "level"
elif markdown.version < '2.6': elif markdown.version < "2.6":
HEADERID_EXT_PATH = 'markdown.extensions.headerid' HEADERID_EXT_PATH = "markdown.extensions.headerid"
LEVEL_PARAM = 'level' LEVEL_PARAM = "level"
else: else:
HEADERID_EXT_PATH = 'markdown.extensions.toc' HEADERID_EXT_PATH = "markdown.extensions.toc"
LEVEL_PARAM = 'baselevel' LEVEL_PARAM = "baselevel"
def apply_markdown(text): def apply_markdown(text):
""" """
@ -200,16 +200,14 @@ try:
of '#' style headers to <h2>. of '#' style headers to <h2>.
""" """
extensions = [HEADERID_EXT_PATH] extensions = [HEADERID_EXT_PATH]
extension_configs = { extension_configs = {HEADERID_EXT_PATH: {LEVEL_PARAM: "2"}}
HEADERID_EXT_PATH: {
LEVEL_PARAM: '2'
}
}
md = markdown.Markdown( md = markdown.Markdown(
extensions=extensions, extension_configs=extension_configs extensions=extensions, extension_configs=extension_configs
) )
md_filter_add_syntax_highlight(md) md_filter_add_syntax_highlight(md)
return md.convert(text) return md.convert(text)
except ImportError: except ImportError:
apply_markdown = None apply_markdown = None
markdown = None markdown = None
@ -227,7 +225,8 @@ try:
def pygments_css(style): def pygments_css(style):
formatter = HtmlFormatter(style=style) formatter = HtmlFormatter(style=style)
return formatter.get_style_defs('.highlight') return formatter.get_style_defs(".highlight")
except ImportError: except ImportError:
pygments = None pygments = None
@ -238,6 +237,7 @@ except ImportError:
def pygments_css(style): def pygments_css(style):
return None return None
if markdown is not None and pygments is not None: if markdown is not None and pygments is not None:
# starting from this blogpost and modified to support current markdown extensions API # starting from this blogpost and modified to support current markdown extensions API
# https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/ # https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/
@ -246,8 +246,7 @@ if markdown is not None and pygments is not None:
import re import re
class CodeBlockPreprocessor(Preprocessor): class CodeBlockPreprocessor(Preprocessor):
pattern = re.compile( pattern = re.compile(r"^\s*``` *([^\n]+)\n(.+?)^\s*```", re.M | re.S)
r'^\s*``` *([^\n]+)\n(.+?)^\s*```', re.M | re.S)
formatter = HtmlFormatter() formatter = HtmlFormatter()
@ -257,17 +256,25 @@ if markdown is not None and pygments is not None:
lexer = get_lexer_by_name(m.group(1)) lexer = get_lexer_by_name(m.group(1))
except (ValueError, NameError): except (ValueError, NameError):
lexer = TextLexer() lexer = TextLexer()
code = m.group(2).replace('\t', ' ') code = m.group(2).replace("\t", " ")
code = pygments.highlight(code, lexer, self.formatter) code = pygments.highlight(code, lexer, self.formatter)
code = code.replace('\n\n', '\n&nbsp;\n').replace('\n', '<br />').replace('\\@', '@') code = (
return '\n\n%s\n\n' % code code.replace("\n\n", "\n&nbsp;\n")
.replace("\n", "<br />")
.replace("\\@", "@")
)
return "\n\n%s\n\n" % code
ret = self.pattern.sub(repl, "\n".join(lines)) ret = self.pattern.sub(repl, "\n".join(lines))
return ret.split("\n") return ret.split("\n")
def md_filter_add_syntax_highlight(md): def md_filter_add_syntax_highlight(md):
md.preprocessors.add('highlight', CodeBlockPreprocessor(), "_begin") md.preprocessors.add("highlight", CodeBlockPreprocessor(), "_begin")
return True return True
else: else:
def md_filter_add_syntax_highlight(md): def md_filter_add_syntax_highlight(md):
return False return False
@ -276,7 +283,8 @@ else:
try: try:
from django.urls import include, path, re_path, register_converter # noqa from django.urls import include, path, re_path, register_converter # noqa
except ImportError: except ImportError:
from django.conf.urls import include, url # noqa from django.conf.urls import include, url # noqa
path = None path = None
register_converter = None register_converter = None
re_path = url re_path = url
@ -285,13 +293,13 @@ except ImportError:
# `separators` argument to `json.dumps()` differs between 2.x and 3.x # `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767 # See: https://bugs.python.org/issue22767
if six.PY3: if six.PY3:
SHORT_SEPARATORS = (',', ':') SHORT_SEPARATORS = (",", ":")
LONG_SEPARATORS = (', ', ': ') LONG_SEPARATORS = (", ", ": ")
INDENT_SEPARATORS = (',', ': ') INDENT_SEPARATORS = (",", ": ")
else: else:
SHORT_SEPARATORS = (b',', b':') SHORT_SEPARATORS = (b",", b":")
LONG_SEPARATORS = (b', ', b': ') LONG_SEPARATORS = (b", ", b": ")
INDENT_SEPARATORS = (b',', b': ') INDENT_SEPARATORS = (b",", b": ")
class CustomValidatorMessage(object): class CustomValidatorMessage(object):
@ -303,7 +311,7 @@ class CustomValidatorMessage(object):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.message = kwargs.pop('message', self.message) self.message = kwargs.pop("message", self.message)
super(CustomValidatorMessage, self).__init__(*args, **kwargs) super(CustomValidatorMessage, self).__init__(*args, **kwargs)

View File

@ -23,14 +23,14 @@ def api_view(http_method_names=None):
Decorator that converts a function-based view into an APIView subclass. Decorator that converts a function-based view into an APIView subclass.
Takes a list of allowed methods for the view as an argument. Takes a list of allowed methods for the view as an argument.
""" """
http_method_names = ['GET'] if (http_method_names is None) else http_method_names http_method_names = ["GET"] if (http_method_names is None) else http_method_names
def decorator(func): def decorator(func):
WrappedAPIView = type( WrappedAPIView = type(
six.PY3 and 'WrappedAPIView' or b'WrappedAPIView', six.PY3 and "WrappedAPIView" or b"WrappedAPIView",
(APIView,), (APIView,),
{'__doc__': func.__doc__} {"__doc__": func.__doc__},
) )
# Note, the above allows us to set the docstring. # Note, the above allows us to set the docstring.
@ -41,15 +41,20 @@ def api_view(http_method_names=None):
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
# api_view applied without (method_names) # api_view applied without (method_names)
assert not(isinstance(http_method_names, types.FunctionType)), \ assert not (
'@api_view missing list of allowed HTTP methods' isinstance(http_method_names, types.FunctionType)
), "@api_view missing list of allowed HTTP methods"
# api_view applied with eg. string instead of list of strings # api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \ assert isinstance(http_method_names, (list, tuple)), (
'@api_view expected a list of strings, received %s' % type(http_method_names).__name__ "@api_view expected a list of strings, received %s"
% type(http_method_names).__name__
)
allowed_methods = set(http_method_names) | {'options'} allowed_methods = set(http_method_names) | {"options"}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] WrappedAPIView.http_method_names = [
method.lower() for method in allowed_methods
]
def handler(self, *args, **kwargs): def handler(self, *args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
@ -60,23 +65,27 @@ def api_view(http_method_names=None):
WrappedAPIView.__name__ = func.__name__ WrappedAPIView.__name__ = func.__name__
WrappedAPIView.__module__ = func.__module__ WrappedAPIView.__module__ = func.__module__
WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes', WrappedAPIView.renderer_classes = getattr(
APIView.renderer_classes) func, "renderer_classes", APIView.renderer_classes
)
WrappedAPIView.parser_classes = getattr(func, 'parser_classes', WrappedAPIView.parser_classes = getattr(
APIView.parser_classes) func, "parser_classes", APIView.parser_classes
)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', WrappedAPIView.authentication_classes = getattr(
APIView.authentication_classes) func, "authentication_classes", APIView.authentication_classes
)
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes', WrappedAPIView.throttle_classes = getattr(
APIView.throttle_classes) func, "throttle_classes", APIView.throttle_classes
)
WrappedAPIView.permission_classes = getattr(func, 'permission_classes', WrappedAPIView.permission_classes = getattr(
APIView.permission_classes) func, "permission_classes", APIView.permission_classes
)
WrappedAPIView.schema = getattr(func, 'schema', WrappedAPIView.schema = getattr(func, "schema", APIView.schema)
APIView.schema)
return WrappedAPIView.as_view() return WrappedAPIView.as_view()
@ -87,6 +96,7 @@ def renderer_classes(renderer_classes):
def decorator(func): def decorator(func):
func.renderer_classes = renderer_classes func.renderer_classes = renderer_classes
return func return func
return decorator return decorator
@ -94,6 +104,7 @@ def parser_classes(parser_classes):
def decorator(func): def decorator(func):
func.parser_classes = parser_classes func.parser_classes = parser_classes
return func return func
return decorator return decorator
@ -101,6 +112,7 @@ def authentication_classes(authentication_classes):
def decorator(func): def decorator(func):
func.authentication_classes = authentication_classes func.authentication_classes = authentication_classes
return func return func
return decorator return decorator
@ -108,6 +120,7 @@ def throttle_classes(throttle_classes):
def decorator(func): def decorator(func):
func.throttle_classes = throttle_classes func.throttle_classes = throttle_classes
return func return func
return decorator return decorator
@ -115,6 +128,7 @@ def permission_classes(permission_classes):
def decorator(func): def decorator(func):
func.permission_classes = permission_classes func.permission_classes = permission_classes
return func return func
return decorator return decorator
@ -122,6 +136,7 @@ def schema(view_inspector):
def decorator(func): def decorator(func):
func.schema = view_inspector func.schema = view_inspector
return func return func
return decorator return decorator
@ -132,15 +147,13 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
Set the `detail` boolean to determine if this action should apply to Set the `detail` boolean to determine if this action should apply to
instance/detail requests or collection/list requests. instance/detail requests or collection/list requests.
""" """
methods = ['get'] if (methods is None) else methods methods = ["get"] if (methods is None) else methods
methods = [method.lower() for method in methods] methods = [method.lower() for method in methods]
assert detail is not None, ( assert detail is not None, "@action() missing required argument: 'detail'"
"@action() missing required argument: 'detail'"
)
# name and suffix are mutually exclusive # name and suffix are mutually exclusive
if 'name' in kwargs and 'suffix' in kwargs: if "name" in kwargs and "suffix" in kwargs:
raise TypeError("`name` and `suffix` are mutually exclusive arguments.") raise TypeError("`name` and `suffix` are mutually exclusive arguments.")
def decorator(func): def decorator(func):
@ -148,15 +161,16 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
func.detail = detail func.detail = detail
func.url_path = url_path if url_path else func.__name__ func.url_path = url_path if url_path else func.__name__
func.url_name = url_name if url_name else func.__name__.replace('_', '-') func.url_name = url_name if url_name else func.__name__.replace("_", "-")
func.kwargs = kwargs func.kwargs = kwargs
# Set descriptive arguments for viewsets # Set descriptive arguments for viewsets
if 'name' not in kwargs and 'suffix' not in kwargs: if "name" not in kwargs and "suffix" not in kwargs:
func.kwargs['name'] = pretty_name(func.__name__) func.kwargs["name"] = pretty_name(func.__name__)
func.kwargs['description'] = func.__doc__ or None func.kwargs["description"] = func.__doc__ or None
return func return func
return decorator return decorator
@ -184,39 +198,42 @@ class MethodMapper(dict):
self[method] = self.action.__name__ self[method] = self.action.__name__
def _map(self, method, func): def _map(self, method, func):
assert method not in self, ( assert method not in self, "Method '%s' has already been mapped to '.%s'." % (
"Method '%s' has already been mapped to '.%s'." % (method, self[method])) method,
self[method],
)
assert func.__name__ != self.action.__name__, ( assert func.__name__ != self.action.__name__, (
"Method mapping does not behave like the property decorator. You " "Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.") "cannot use the same method name for each mapping declaration."
)
self[method] = func.__name__ self[method] = func.__name__
return func return func
def get(self, func): def get(self, func):
return self._map('get', func) return self._map("get", func)
def post(self, func): def post(self, func):
return self._map('post', func) return self._map("post", func)
def put(self, func): def put(self, func):
return self._map('put', func) return self._map("put", func)
def patch(self, func): def patch(self, func):
return self._map('patch', func) return self._map("patch", func)
def delete(self, func): def delete(self, func):
return self._map('delete', func) return self._map("delete", func)
def head(self, func): def head(self, func):
return self._map('head', func) return self._map("head", func)
def options(self, func): def options(self, func):
return self._map('options', func) return self._map("options", func)
def trace(self, func): def trace(self, func):
return self._map('trace', func) return self._map("trace", func)
def detail_route(methods=None, **kwargs): def detail_route(methods=None, **kwargs):
@ -226,14 +243,16 @@ def detail_route(methods=None, **kwargs):
warnings.warn( warnings.warn(
"`detail_route` is deprecated and will be removed in 3.10 in favor of " "`detail_route` is deprecated and will be removed in 3.10 in favor of "
"`action`, which accepts a `detail` bool. Use `@action(detail=True)` instead.", "`action`, which accepts a `detail` bool. Use `@action(detail=True)` instead.",
RemovedInDRF310Warning, stacklevel=2 RemovedInDRF310Warning,
stacklevel=2,
) )
def decorator(func): def decorator(func):
func = action(methods, detail=True, **kwargs)(func) func = action(methods, detail=True, **kwargs)(func)
if 'url_name' not in kwargs: if "url_name" not in kwargs:
func.url_name = func.url_path.replace('_', '-') func.url_name = func.url_path.replace("_", "-")
return func return func
return decorator return decorator
@ -244,12 +263,14 @@ def list_route(methods=None, **kwargs):
warnings.warn( warnings.warn(
"`list_route` is deprecated and will be removed in 3.10 in favor of " "`list_route` is deprecated and will be removed in 3.10 in favor of "
"`action`, which accepts a `detail` bool. Use `@action(detail=False)` instead.", "`action`, which accepts a `detail` bool. Use `@action(detail=False)` instead.",
RemovedInDRF310Warning, stacklevel=2 RemovedInDRF310Warning,
stacklevel=2,
) )
def decorator(func): def decorator(func):
func = action(methods, detail=False, **kwargs)(func) func = action(methods, detail=False, **kwargs)(func)
if 'url_name' not in kwargs: if "url_name" not in kwargs:
func.url_name = func.url_path.replace('_', '-') func.url_name = func.url_path.replace("_", "-")
return func return func
return decorator return decorator

View File

@ -1,18 +1,25 @@
from django.conf.urls import include, url from django.conf.urls import include, url
from rest_framework.renderers import ( from rest_framework.renderers import (
CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer CoreJSONRenderer,
DocumentationRenderer,
SchemaJSRenderer,
) )
from rest_framework.schemas import SchemaGenerator, get_schema_view from rest_framework.schemas import SchemaGenerator, get_schema_view
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
def get_docs_view( def get_docs_view(
title=None, description=None, schema_url=None, public=True, title=None,
patterns=None, generator_class=SchemaGenerator, description=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, schema_url=None,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, public=True,
renderer_classes=None): patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None,
):
if renderer_classes is None: if renderer_classes is None:
renderer_classes = [DocumentationRenderer, CoreJSONRenderer] renderer_classes = [DocumentationRenderer, CoreJSONRenderer]
@ -31,10 +38,15 @@ def get_docs_view(
def get_schemajs_view( def get_schemajs_view(
title=None, description=None, schema_url=None, public=True, title=None,
patterns=None, generator_class=SchemaGenerator, description=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, schema_url=None,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): public=True,
patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
):
renderer_classes = [SchemaJSRenderer] renderer_classes = [SchemaJSRenderer]
return get_schema_view( return get_schema_view(
@ -51,11 +63,16 @@ def get_schemajs_view(
def include_docs_urls( def include_docs_urls(
title=None, description=None, schema_url=None, public=True, title=None,
patterns=None, generator_class=SchemaGenerator, description=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, schema_url=None,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, public=True,
renderer_classes=None): patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None,
):
docs_view = get_docs_view( docs_view = get_docs_view(
title=title, title=title,
description=description, description=description,
@ -78,7 +95,7 @@ def include_docs_urls(
permission_classes=permission_classes, permission_classes=permission_classes,
) )
urls = [ urls = [
url(r'^$', docs_view, name='docs-index'), url(r"^$", docs_view, name="docs-index"),
url(r'^schema.js$', schema_js_view, name='schema-js') url(r"^schema.js$", schema_js_view, name="schema-js"),
] ]
return include((urls, 'api-docs'), namespace='api-docs') return include((urls, "api-docs"), namespace="api-docs")

View File

@ -11,8 +11,7 @@ import math
from django.http import JsonResponse from django.http import JsonResponse
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _, ungettext
from django.utils.translation import ungettext
from rest_framework import status from rest_framework import status
from rest_framework.compat import unicode_to_repr from rest_framework.compat import unicode_to_repr
@ -25,23 +24,20 @@ def _get_error_details(data, default_code=None):
lazy translation strings or strings into `ErrorDetail`. lazy translation strings or strings into `ErrorDetail`.
""" """
if isinstance(data, list): if isinstance(data, list):
ret = [ ret = [_get_error_details(item, default_code) for item in data]
_get_error_details(item, default_code) for item in data
]
if isinstance(data, ReturnList): if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer) return ReturnList(ret, serializer=data.serializer)
return ret return ret
elif isinstance(data, dict): elif isinstance(data, dict):
ret = { ret = {
key: _get_error_details(value, default_code) key: _get_error_details(value, default_code) for key, value in data.items()
for key, value in data.items()
} }
if isinstance(data, ReturnDict): if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer) return ReturnDict(ret, serializer=data.serializer)
return ret return ret
text = force_text(data) text = force_text(data)
code = getattr(data, 'code', default_code) code = getattr(data, "code", default_code)
return ErrorDetail(text, code) return ErrorDetail(text, code)
@ -58,16 +54,14 @@ def _get_full_details(detail):
return [_get_full_details(item) for item in detail] return [_get_full_details(item) for item in detail]
elif isinstance(detail, dict): elif isinstance(detail, dict):
return {key: _get_full_details(value) for key, value in detail.items()} return {key: _get_full_details(value) for key, value in detail.items()}
return { return {"message": detail, "code": detail.code}
'message': detail,
'code': detail.code
}
class ErrorDetail(six.text_type): class ErrorDetail(six.text_type):
""" """
A string-like object that can additionally have a code. A string-like object that can additionally have a code.
""" """
code = None code = None
def __new__(cls, string, code=None): def __new__(cls, string, code=None):
@ -86,10 +80,9 @@ class ErrorDetail(six.text_type):
return not self.__eq__(other) return not self.__eq__(other)
def __repr__(self): def __repr__(self):
return unicode_to_repr('ErrorDetail(string=%r, code=%r)' % ( return unicode_to_repr(
six.text_type(self), "ErrorDetail(string=%r, code=%r)" % (six.text_type(self), self.code)
self.code, )
))
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))
@ -100,9 +93,10 @@ class APIException(Exception):
Base class for REST framework exceptions. Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.default_detail` properties. Subclasses should provide `.status_code` and `.default_detail` properties.
""" """
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = _('A server error occurred.') default_detail = _("A server error occurred.")
default_code = 'error' default_code = "error"
def __init__(self, detail=None, code=None): def __init__(self, detail=None, code=None):
if detail is None: if detail is None:
@ -139,10 +133,11 @@ class APIException(Exception):
# from rest_framework import serializers # from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid') # raise serializers.ValidationError('Value was invalid')
class ValidationError(APIException): class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Invalid input.') default_detail = _("Invalid input.")
default_code = 'invalid' default_code = "invalid"
def __init__(self, detail=None, code=None): def __init__(self, detail=None, code=None):
if detail is None: if detail is None:
@ -160,38 +155,38 @@ class ValidationError(APIException):
class ParseError(APIException): class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Malformed request.') default_detail = _("Malformed request.")
default_code = 'parse_error' default_code = "parse_error"
class AuthenticationFailed(APIException): class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Incorrect authentication credentials.') default_detail = _("Incorrect authentication credentials.")
default_code = 'authentication_failed' default_code = "authentication_failed"
class NotAuthenticated(APIException): class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Authentication credentials were not provided.') default_detail = _("Authentication credentials were not provided.")
default_code = 'not_authenticated' default_code = "not_authenticated"
class PermissionDenied(APIException): class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN status_code = status.HTTP_403_FORBIDDEN
default_detail = _('You do not have permission to perform this action.') default_detail = _("You do not have permission to perform this action.")
default_code = 'permission_denied' default_code = "permission_denied"
class NotFound(APIException): class NotFound(APIException):
status_code = status.HTTP_404_NOT_FOUND status_code = status.HTTP_404_NOT_FOUND
default_detail = _('Not found.') default_detail = _("Not found.")
default_code = 'not_found' default_code = "not_found"
class MethodNotAllowed(APIException): class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = _('Method "{method}" not allowed.') default_detail = _('Method "{method}" not allowed.')
default_code = 'method_not_allowed' default_code = "method_not_allowed"
def __init__(self, method, detail=None, code=None): def __init__(self, method, detail=None, code=None):
if detail is None: if detail is None:
@ -201,8 +196,8 @@ class MethodNotAllowed(APIException):
class NotAcceptable(APIException): class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE status_code = status.HTTP_406_NOT_ACCEPTABLE
default_detail = _('Could not satisfy the request Accept header.') default_detail = _("Could not satisfy the request Accept header.")
default_code = 'not_acceptable' default_code = "not_acceptable"
def __init__(self, detail=None, code=None, available_renderers=None): def __init__(self, detail=None, code=None, available_renderers=None):
self.available_renderers = available_renderers self.available_renderers = available_renderers
@ -212,7 +207,7 @@ class NotAcceptable(APIException):
class UnsupportedMediaType(APIException): class UnsupportedMediaType(APIException):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
default_detail = _('Unsupported media type "{media_type}" in request.') default_detail = _('Unsupported media type "{media_type}" in request.')
default_code = 'unsupported_media_type' default_code = "unsupported_media_type"
def __init__(self, media_type, detail=None, code=None): def __init__(self, media_type, detail=None, code=None):
if detail is None: if detail is None:
@ -222,21 +217,28 @@ class UnsupportedMediaType(APIException):
class Throttled(APIException): class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_detail = _('Request was throttled.') default_detail = _("Request was throttled.")
extra_detail_singular = 'Expected available in {wait} second.' extra_detail_singular = "Expected available in {wait} second."
extra_detail_plural = 'Expected available in {wait} seconds.' extra_detail_plural = "Expected available in {wait} seconds."
default_code = 'throttled' default_code = "throttled"
def __init__(self, wait=None, detail=None, code=None): def __init__(self, wait=None, detail=None, code=None):
if detail is None: if detail is None:
detail = force_text(self.default_detail) detail = force_text(self.default_detail)
if wait is not None: if wait is not None:
wait = math.ceil(wait) wait = math.ceil(wait)
detail = ' '.join(( detail = " ".join(
detail, (
force_text(ungettext(self.extra_detail_singular.format(wait=wait), detail,
self.extra_detail_plural.format(wait=wait), force_text(
wait)))) ungettext(
self.extra_detail_singular.format(wait=wait),
self.extra_detail_plural.format(wait=wait),
wait,
)
),
)
)
self.wait = wait self.wait = wait
super(Throttled, self).__init__(detail, code) super(Throttled, self).__init__(detail, code)
@ -245,9 +247,7 @@ def server_error(request, *args, **kwargs):
""" """
Generic 500 error handler. Generic 500 error handler.
""" """
data = { data = {"error": "Server Error (500)"}
'error': 'Server Error (500)'
}
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR) return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@ -255,7 +255,5 @@ def bad_request(request, exception, *args, **kwargs):
""" """
Generic 400 error handler. Generic 400 error handler.
""" """
data = { data = {"error": "Bad Request (400)"}
'error': 'Bad Request (400)'
}
return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST) return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST)

File diff suppressed because it is too large Load Diff

View File

@ -18,9 +18,7 @@ from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import RemovedInDRF310Warning from rest_framework import RemovedInDRF310Warning
from rest_framework.compat import ( from rest_framework.compat import coreapi, coreschema, distinct, is_guardian_installed
coreapi, coreschema, distinct, is_guardian_installed
)
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -36,23 +34,22 @@ class BaseFilterBackend(object):
raise NotImplementedError(".filter_queryset() must be overridden.") raise NotImplementedError(".filter_queryset() must be overridden.")
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert (
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return [] return []
class SearchFilter(BaseFilterBackend): class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search. # The URL query parameter used for the search.
search_param = api_settings.SEARCH_PARAM search_param = api_settings.SEARCH_PARAM
template = 'rest_framework/filters/search.html' template = "rest_framework/filters/search.html"
lookup_prefixes = { lookup_prefixes = {"^": "istartswith", "=": "iexact", "@": "search", "$": "iregex"}
'^': 'istartswith', search_title = _("Search")
'=': 'iexact', search_description = _("A search term.")
'@': 'search',
'$': 'iregex',
}
search_title = _('Search')
search_description = _('A search term.')
def get_search_fields(self, view, request): def get_search_fields(self, view, request):
""" """
@ -60,22 +57,22 @@ class SearchFilter(BaseFilterBackend):
passed to this method. Sub-classes can override this method to passed to this method. Sub-classes can override this method to
dynamically change the search fields based on request content. dynamically change the search fields based on request content.
""" """
return getattr(view, 'search_fields', None) return getattr(view, "search_fields", None)
def get_search_terms(self, request): def get_search_terms(self, request):
""" """
Search terms are set by a ?search=... query parameter, Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited. and may be comma and/or whitespace delimited.
""" """
params = request.query_params.get(self.search_param, '') params = request.query_params.get(self.search_param, "")
return params.replace(',', ' ').split() return params.replace(",", " ").split()
def construct_search(self, field_name): def construct_search(self, field_name):
lookup = self.lookup_prefixes.get(field_name[0]) lookup = self.lookup_prefixes.get(field_name[0])
if lookup: if lookup:
field_name = field_name[1:] field_name = field_name[1:]
else: else:
lookup = 'icontains' lookup = "icontains"
return LOOKUP_SEP.join([field_name, lookup]) return LOOKUP_SEP.join([field_name, lookup])
def must_call_distinct(self, queryset, search_fields): def must_call_distinct(self, queryset, search_fields):
@ -87,12 +84,15 @@ class SearchFilter(BaseFilterBackend):
if search_field[0] in self.lookup_prefixes: if search_field[0] in self.lookup_prefixes:
search_field = search_field[1:] search_field = search_field[1:]
# Annotated fields do not need to be distinct # Annotated fields do not need to be distinct
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: if (
isinstance(queryset, models.QuerySet)
and search_field in queryset.query.annotations
):
return False return False
parts = search_field.split(LOOKUP_SEP) parts = search_field.split(LOOKUP_SEP)
for part in parts: for part in parts:
field = opts.get_field(part) field = opts.get_field(part)
if hasattr(field, 'get_path_info'): if hasattr(field, "get_path_info"):
# This field is a relation, update opts to follow the relation # This field is a relation, update opts to follow the relation
path_info = field.get_path_info() path_info = field.get_path_info()
opts = path_info[-1].to_opts opts = path_info[-1].to_opts
@ -117,8 +117,7 @@ class SearchFilter(BaseFilterBackend):
conditions = [] conditions = []
for search_term in search_terms: for search_term in search_terms:
queries = [ queries = [
models.Q(**{orm_lookup: search_term}) models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups
for orm_lookup in orm_lookups
] ]
conditions.append(reduce(operator.or_, queries)) conditions.append(reduce(operator.or_, queries))
queryset = queryset.filter(reduce(operator.and_, conditions)) queryset = queryset.filter(reduce(operator.and_, conditions))
@ -132,30 +131,31 @@ class SearchFilter(BaseFilterBackend):
return queryset return queryset
def to_html(self, request, queryset, view): def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None): if not getattr(view, "search_fields", None):
return '' return ""
term = self.get_search_terms(request) term = self.get_search_terms(request)
term = term[0] if term else '' term = term[0] if term else ""
context = { context = {"param": self.search_param, "term": term}
'param': self.search_param,
'term': term
}
template = loader.get_template(self.template) template = loader.get_template(self.template)
return template.render(context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert (
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return [ return [
coreapi.Field( coreapi.Field(
name=self.search_param, name=self.search_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.String( schema=coreschema.String(
title=force_text(self.search_title), title=force_text(self.search_title),
description=force_text(self.search_description) description=force_text(self.search_description),
) ),
) )
] ]
@ -164,9 +164,9 @@ class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering. # The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None ordering_fields = None
ordering_title = _('Ordering') ordering_title = _("Ordering")
ordering_description = _('Which field to use when ordering the results.') ordering_description = _("Which field to use when ordering the results.")
template = 'rest_framework/filters/ordering.html' template = "rest_framework/filters/ordering.html"
def get_ordering(self, request, queryset, view): def get_ordering(self, request, queryset, view):
""" """
@ -178,7 +178,7 @@ class OrderingFilter(BaseFilterBackend):
""" """
params = request.query_params.get(self.ordering_param) params = request.query_params.get(self.ordering_param)
if params: if params:
fields = [param.strip() for param in params.split(',')] fields = [param.strip() for param in params.split(",")]
ordering = self.remove_invalid_fields(queryset, fields, view, request) ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering: if ordering:
return ordering return ordering
@ -187,7 +187,7 @@ class OrderingFilter(BaseFilterBackend):
return self.get_default_ordering(view) return self.get_default_ordering(view)
def get_default_ordering(self, view): def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None) ordering = getattr(view, "ordering", None)
if isinstance(ordering, six.string_types): if isinstance(ordering, six.string_types):
return (ordering,) return (ordering,)
return ordering return ordering
@ -195,7 +195,7 @@ class OrderingFilter(BaseFilterBackend):
def get_default_valid_fields(self, queryset, view, context={}): def get_default_valid_fields(self, queryset, view, context={}):
# If `ordering_fields` is not specified, then we determine a default # If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view. # based on the serializer class, if one exists on the view.
if hasattr(view, 'get_serializer_class'): if hasattr(view, "get_serializer_class"):
try: try:
serializer_class = view.get_serializer_class() serializer_class = view.get_serializer_class()
except AssertionError: except AssertionError:
@ -203,7 +203,7 @@ class OrderingFilter(BaseFilterBackend):
# no serializer_class was found # no serializer_class was found
serializer_class = None serializer_class = None
else: else:
serializer_class = getattr(view, 'serializer_class', None) serializer_class = getattr(view, "serializer_class", None)
if serializer_class is None: if serializer_class is None:
msg = ( msg = (
@ -214,26 +214,26 @@ class OrderingFilter(BaseFilterBackend):
raise ImproperlyConfigured(msg % self.__class__.__name__) raise ImproperlyConfigured(msg % self.__class__.__name__)
return [ return [
(field.source.replace('.', '__') or field_name, field.label) (field.source.replace(".", "__") or field_name, field.label)
for field_name, field in serializer_class(context=context).fields.items() for field_name, field in serializer_class(context=context).fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*' if not getattr(field, "write_only", False) and not field.source == "*"
] ]
def get_valid_fields(self, queryset, view, context={}): def get_valid_fields(self, queryset, view, context={}):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) valid_fields = getattr(view, "ordering_fields", self.ordering_fields)
if valid_fields is None: if valid_fields is None:
# Default to allowing filtering on serializer fields # Default to allowing filtering on serializer fields
return self.get_default_valid_fields(queryset, view, context) return self.get_default_valid_fields(queryset, view, context)
elif valid_fields == '__all__': elif valid_fields == "__all__":
# View explicitly allows filtering on any model field # View explicitly allows filtering on any model field
valid_fields = [ valid_fields = [
(field.name, field.verbose_name) for field in queryset.model._meta.fields (field.name, field.verbose_name)
for field in queryset.model._meta.fields
] ]
valid_fields += [ valid_fields += [
(key, key.title().split('__')) (key, key.title().split("__")) for key in queryset.query.annotations
for key in queryset.query.annotations
] ]
else: else:
valid_fields = [ valid_fields = [
@ -244,8 +244,15 @@ class OrderingFilter(BaseFilterBackend):
return valid_fields return valid_fields
def remove_invalid_fields(self, queryset, fields, view, request): def remove_invalid_fields(self, queryset, fields, view, request):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})] valid_fields = [
return [term for term in fields if term.lstrip('-') in valid_fields and ORDER_PATTERN.match(term)] item[0]
for item in self.get_valid_fields(queryset, view, {"request": request})
]
return [
term
for term in fields
if term.lstrip("-") in valid_fields and ORDER_PATTERN.match(term)
]
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view) ordering = self.get_ordering(request, queryset, view)
@ -259,15 +266,11 @@ class OrderingFilter(BaseFilterBackend):
current = self.get_ordering(request, queryset, view) current = self.get_ordering(request, queryset, view)
current = None if not current else current[0] current = None if not current else current[0]
options = [] options = []
context = { context = {"request": request, "current": current, "param": self.ordering_param}
'request': request,
'current': current,
'param': self.ordering_param,
}
for key, label in self.get_valid_fields(queryset, view, context): for key, label in self.get_valid_fields(queryset, view, context):
options.append((key, '%s - %s' % (label, _('ascending')))) options.append((key, "%s - %s" % (label, _("ascending"))))
options.append(('-' + key, '%s - %s' % (label, _('descending')))) options.append(("-" + key, "%s - %s" % (label, _("descending"))))
context['options'] = options context["options"] = options
return context return context
def to_html(self, request, queryset, view): def to_html(self, request, queryset, view):
@ -276,17 +279,21 @@ class OrderingFilter(BaseFilterBackend):
return template.render(context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert (
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return [ return [
coreapi.Field( coreapi.Field(
name=self.ordering_param, name=self.ordering_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.String( schema=coreschema.String(
title=force_text(self.ordering_title), title=force_text(self.ordering_title),
description=force_text(self.ordering_description) description=force_text(self.ordering_description),
) ),
) )
] ]
@ -296,15 +303,19 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):
A filter backend that limits results to those where the requesting user A filter backend that limits results to those where the requesting user
has read object level permissions. has read object level permissions.
""" """
def __init__(self): def __init__(self):
warnings.warn( warnings.warn(
"`DjangoObjectPermissionsFilter` has been deprecated and moved to " "`DjangoObjectPermissionsFilter` has been deprecated and moved to "
"the 3rd-party django-rest-framework-guardian package.", "the 3rd-party django-rest-framework-guardian package.",
RemovedInDRF310Warning, stacklevel=2 RemovedInDRF310Warning,
stacklevel=2,
) )
assert is_guardian_installed(), 'Using DjangoObjectPermissionsFilter, but django-guardian is not installed' assert (
is_guardian_installed()
), "Using DjangoObjectPermissionsFilter, but django-guardian is not installed"
perm_format = '%(app_label)s.view_%(model_name)s' perm_format = "%(app_label)s.view_%(model_name)s"
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
# We want to defer this import until run-time, rather than import-time. # We want to defer this import until run-time, rather than import-time.
@ -317,13 +328,13 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):
user = request.user user = request.user
model_cls = queryset.model model_cls = queryset.model
kwargs = { kwargs = {
'app_label': model_cls._meta.app_label, "app_label": model_cls._meta.app_label,
'model_name': model_cls._meta.model_name "model_name": model_cls._meta.model_name,
} }
permission = self.perm_format % kwargs permission = self.perm_format % kwargs
if tuple(guardian_version) >= (1, 3): if tuple(guardian_version) >= (1, 3):
# Maintain behavior compatibility with versions prior to 1.3 # Maintain behavior compatibility with versions prior to 1.3
extra = {'accept_global_perms': False} extra = {"accept_global_perms": False}
else: else:
extra = {} extra = {}
return get_objects_for_user(user, permission, queryset, **extra) return get_objects_for_user(user, permission, queryset, **extra)

View File

@ -27,6 +27,7 @@ class GenericAPIView(views.APIView):
""" """
Base class for all other generic views. Base class for all other generic views.
""" """
# You'll need to either set these attributes, # You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`. # or override `get_queryset()`/`get_serializer_class()`.
# If you are overriding a view method, it is important that you call # If you are overriding a view method, it is important that you call
@ -38,7 +39,7 @@ class GenericAPIView(views.APIView):
# If you want to use object lookups other than pk, set 'lookup_field'. # If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`. # For more complex lookup requirements override `get_object()`.
lookup_field = 'pk' lookup_field = "pk"
lookup_url_kwarg = None lookup_url_kwarg = None
# The filter backend classes to use for queryset filtering # The filter backend classes to use for queryset filtering
@ -64,8 +65,7 @@ class GenericAPIView(views.APIView):
""" """
assert self.queryset is not None, ( assert self.queryset is not None, (
"'%s' should either include a `queryset` attribute, " "'%s' should either include a `queryset` attribute, "
"or override the `get_queryset()` method." "or override the `get_queryset()` method." % self.__class__.__name__
% self.__class__.__name__
) )
queryset = self.queryset queryset = self.queryset
@ -88,10 +88,10 @@ class GenericAPIView(views.APIView):
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
assert lookup_url_kwarg in self.kwargs, ( assert lookup_url_kwarg in self.kwargs, (
'Expected view %s to be called with a URL keyword argument ' "Expected view %s to be called with a URL keyword argument "
'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' % "attribute on the view correctly."
(self.__class__.__name__, lookup_url_kwarg) % (self.__class__.__name__, lookup_url_kwarg)
) )
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
@ -108,7 +108,7 @@ class GenericAPIView(views.APIView):
deserializing input, and for serializing output. deserializing input, and for serializing output.
""" """
serializer_class = self.get_serializer_class() serializer_class = self.get_serializer_class()
kwargs['context'] = self.get_serializer_context() kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs) return serializer_class(*args, **kwargs)
def get_serializer_class(self): def get_serializer_class(self):
@ -123,8 +123,7 @@ class GenericAPIView(views.APIView):
""" """
assert self.serializer_class is not None, ( assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, " "'%s' should either include a `serializer_class` attribute, "
"or override the `get_serializer_class()` method." "or override the `get_serializer_class()` method." % self.__class__.__name__
% self.__class__.__name__
) )
return self.serializer_class return self.serializer_class
@ -133,11 +132,7 @@ class GenericAPIView(views.APIView):
""" """
Extra context provided to the serializer class. Extra context provided to the serializer class.
""" """
return { return {"request": self.request, "format": self.format_kwarg, "view": self}
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def filter_queryset(self, queryset): def filter_queryset(self, queryset):
""" """
@ -157,7 +152,7 @@ class GenericAPIView(views.APIView):
""" """
The paginator instance associated with the view, or `None`. The paginator instance associated with the view, or `None`.
""" """
if not hasattr(self, '_paginator'): if not hasattr(self, "_paginator"):
if self.pagination_class is None: if self.pagination_class is None:
self._paginator = None self._paginator = None
else: else:
@ -183,47 +178,48 @@ class GenericAPIView(views.APIView):
# Concrete view classes that provide method handlers # Concrete view classes that provide method handlers
# by composing the mixin classes with the base view. # by composing the mixin classes with the base view.
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView): class CreateAPIView(mixins.CreateModelMixin, GenericAPIView):
""" """
Concrete view for creating a model instance. Concrete view for creating a model instance.
""" """
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs) return self.create(request, *args, **kwargs)
class ListAPIView(mixins.ListModelMixin, class ListAPIView(mixins.ListModelMixin, GenericAPIView):
GenericAPIView):
""" """
Concrete view for listing a queryset. Concrete view for listing a queryset.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs) return self.list(request, *args, **kwargs)
class RetrieveAPIView(mixins.RetrieveModelMixin, class RetrieveAPIView(mixins.RetrieveModelMixin, GenericAPIView):
GenericAPIView):
""" """
Concrete view for retrieving a model instance. Concrete view for retrieving a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) return self.retrieve(request, *args, **kwargs)
class DestroyAPIView(mixins.DestroyModelMixin, class DestroyAPIView(mixins.DestroyModelMixin, GenericAPIView):
GenericAPIView):
""" """
Concrete view for deleting a model instance. Concrete view for deleting a model instance.
""" """
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs) return self.destroy(request, *args, **kwargs)
class UpdateAPIView(mixins.UpdateModelMixin, class UpdateAPIView(mixins.UpdateModelMixin, GenericAPIView):
GenericAPIView):
""" """
Concrete view for updating a model instance. Concrete view for updating a model instance.
""" """
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs) return self.update(request, *args, **kwargs)
@ -231,12 +227,11 @@ class UpdateAPIView(mixins.UpdateModelMixin,
return self.partial_update(request, *args, **kwargs) return self.partial_update(request, *args, **kwargs)
class ListCreateAPIView(mixins.ListModelMixin, class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, GenericAPIView):
mixins.CreateModelMixin,
GenericAPIView):
""" """
Concrete view for listing a queryset or creating a model instance. Concrete view for listing a queryset or creating a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs) return self.list(request, *args, **kwargs)
@ -244,12 +239,13 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs) return self.create(request, *args, **kwargs)
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateAPIView(
mixins.UpdateModelMixin, mixins.RetrieveModelMixin, mixins.UpdateModelMixin, GenericAPIView
GenericAPIView): ):
""" """
Concrete view for retrieving, updating a model instance. Concrete view for retrieving, updating a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) return self.retrieve(request, *args, **kwargs)
@ -260,12 +256,13 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
return self.partial_update(request, *args, **kwargs) return self.partial_update(request, *args, **kwargs)
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveDestroyAPIView(
mixins.DestroyModelMixin, mixins.RetrieveModelMixin, mixins.DestroyModelMixin, GenericAPIView
GenericAPIView): ):
""" """
Concrete view for retrieving or deleting a model instance. Concrete view for retrieving or deleting a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) return self.retrieve(request, *args, **kwargs)
@ -273,13 +270,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
return self.destroy(request, *args, **kwargs) return self.destroy(request, *args, **kwargs)
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(
mixins.UpdateModelMixin, mixins.RetrieveModelMixin,
mixins.DestroyModelMixin, mixins.UpdateModelMixin,
GenericAPIView): mixins.DestroyModelMixin,
GenericAPIView,
):
""" """
Concrete view for retrieving, updating or deleting a model instance. Concrete view for retrieving, updating or deleting a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) return self.retrieve(request, *args, **kwargs)

View File

@ -2,7 +2,9 @@ from django.core.management.base import BaseCommand
from rest_framework.compat import coreapi from rest_framework.compat import coreapi
from rest_framework.renderers import ( from rest_framework.renderers import (
CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer CoreJSONRenderer,
JSONOpenAPIRenderer,
OpenAPIRenderer,
) )
from rest_framework.schemas.generators import SchemaGenerator from rest_framework.schemas.generators import SchemaGenerator
@ -11,31 +13,37 @@ class Command(BaseCommand):
help = "Generates configured API schema for project." help = "Generates configured API schema for project."
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument('--title', dest="title", default=None, type=str) parser.add_argument("--title", dest="title", default=None, type=str)
parser.add_argument('--url', dest="url", default=None, type=str) parser.add_argument("--url", dest="url", default=None, type=str)
parser.add_argument('--description', dest="description", default=None, type=str) parser.add_argument("--description", dest="description", default=None, type=str)
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) parser.add_argument(
"--format",
dest="format",
choices=["openapi", "openapi-json", "corejson"],
default="openapi",
type=str,
)
def handle(self, *args, **options): def handle(self, *args, **options):
assert coreapi is not None, 'coreapi must be installed.' assert coreapi is not None, "coreapi must be installed."
generator = SchemaGenerator( generator = SchemaGenerator(
url=options['url'], url=options["url"],
title=options['title'], title=options["title"],
description=options['description'] description=options["description"],
) )
schema = generator.get_schema(request=None, public=True) schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer(options['format']) renderer = self.get_renderer(options["format"])
output = renderer.render(schema, renderer_context={}) output = renderer.render(schema, renderer_context={})
self.stdout.write(output.decode('utf-8')) self.stdout.write(output.decode("utf-8"))
def get_renderer(self, format): def get_renderer(self, format):
renderer_cls = { renderer_cls = {
'corejson': CoreJSONRenderer, "corejson": CoreJSONRenderer,
'openapi': OpenAPIRenderer, "openapi": OpenAPIRenderer,
'openapi-json': JSONOpenAPIRenderer, "openapi-json": JSONOpenAPIRenderer,
}[format] }[format]
return renderer_cls() return renderer_cls()

View File

@ -35,41 +35,46 @@ class SimpleMetadata(BaseMetadata):
There are not any formalized standards for `OPTIONS` responses There are not any formalized standards for `OPTIONS` responses
for us to base this on. for us to base this on.
""" """
label_lookup = ClassLookupDict({
serializers.Field: 'field', label_lookup = ClassLookupDict(
serializers.BooleanField: 'boolean', {
serializers.NullBooleanField: 'boolean', serializers.Field: "field",
serializers.CharField: 'string', serializers.BooleanField: "boolean",
serializers.UUIDField: 'string', serializers.NullBooleanField: "boolean",
serializers.URLField: 'url', serializers.CharField: "string",
serializers.EmailField: 'email', serializers.UUIDField: "string",
serializers.RegexField: 'regex', serializers.URLField: "url",
serializers.SlugField: 'slug', serializers.EmailField: "email",
serializers.IntegerField: 'integer', serializers.RegexField: "regex",
serializers.FloatField: 'float', serializers.SlugField: "slug",
serializers.DecimalField: 'decimal', serializers.IntegerField: "integer",
serializers.DateField: 'date', serializers.FloatField: "float",
serializers.DateTimeField: 'datetime', serializers.DecimalField: "decimal",
serializers.TimeField: 'time', serializers.DateField: "date",
serializers.ChoiceField: 'choice', serializers.DateTimeField: "datetime",
serializers.MultipleChoiceField: 'multiple choice', serializers.TimeField: "time",
serializers.FileField: 'file upload', serializers.ChoiceField: "choice",
serializers.ImageField: 'image upload', serializers.MultipleChoiceField: "multiple choice",
serializers.ListField: 'list', serializers.FileField: "file upload",
serializers.DictField: 'nested object', serializers.ImageField: "image upload",
serializers.Serializer: 'nested object', serializers.ListField: "list",
}) serializers.DictField: "nested object",
serializers.Serializer: "nested object",
}
)
def determine_metadata(self, request, view): def determine_metadata(self, request, view):
metadata = OrderedDict() metadata = OrderedDict()
metadata['name'] = view.get_view_name() metadata["name"] = view.get_view_name()
metadata['description'] = view.get_view_description() metadata["description"] = view.get_view_description()
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes] metadata["renders"] = [
metadata['parses'] = [parser.media_type for parser in view.parser_classes] renderer.media_type for renderer in view.renderer_classes
if hasattr(view, 'get_serializer'): ]
metadata["parses"] = [parser.media_type for parser in view.parser_classes]
if hasattr(view, "get_serializer"):
actions = self.determine_actions(request, view) actions = self.determine_actions(request, view)
if actions: if actions:
metadata['actions'] = actions metadata["actions"] = actions
return metadata return metadata
def determine_actions(self, request, view): def determine_actions(self, request, view):
@ -78,14 +83,14 @@ class SimpleMetadata(BaseMetadata):
the fields that are accepted for 'PUT' and 'POST' methods. the fields that are accepted for 'PUT' and 'POST' methods.
""" """
actions = {} actions = {}
for method in {'PUT', 'POST'} & set(view.allowed_methods): for method in {"PUT", "POST"} & set(view.allowed_methods):
view.request = clone_request(request, method) view.request = clone_request(request, method)
try: try:
# Test global permissions # Test global permissions
if hasattr(view, 'check_permissions'): if hasattr(view, "check_permissions"):
view.check_permissions(view.request) view.check_permissions(view.request)
# Test object permissions # Test object permissions
if method == 'PUT' and hasattr(view, 'get_object'): if method == "PUT" and hasattr(view, "get_object"):
view.get_object() view.get_object()
except (exceptions.APIException, PermissionDenied, Http404): except (exceptions.APIException, PermissionDenied, Http404):
pass pass
@ -104,15 +109,17 @@ class SimpleMetadata(BaseMetadata):
Given an instance of a serializer, return a dictionary of metadata Given an instance of a serializer, return a dictionary of metadata
about its fields. about its fields.
""" """
if hasattr(serializer, 'child'): if hasattr(serializer, "child"):
# If this is a `ListSerializer` then we want to examine the # If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead. # underlying child serializer instance instead.
serializer = serializer.child serializer = serializer.child
return OrderedDict([ return OrderedDict(
(field_name, self.get_field_info(field)) [
for field_name, field in serializer.fields.items() (field_name, self.get_field_info(field))
if not isinstance(field, serializers.HiddenField) for field_name, field in serializer.fields.items()
]) if not isinstance(field, serializers.HiddenField)
]
)
def get_field_info(self, field): def get_field_info(self, field):
""" """
@ -120,32 +127,40 @@ class SimpleMetadata(BaseMetadata):
of metadata about it. of metadata about it.
""" """
field_info = OrderedDict() field_info = OrderedDict()
field_info['type'] = self.label_lookup[field] field_info["type"] = self.label_lookup[field]
field_info['required'] = getattr(field, 'required', False) field_info["required"] = getattr(field, "required", False)
attrs = [ attrs = [
'read_only', 'label', 'help_text', "read_only",
'min_length', 'max_length', "label",
'min_value', 'max_value' "help_text",
"min_length",
"max_length",
"min_value",
"max_value",
] ]
for attr in attrs: for attr in attrs:
value = getattr(field, attr, None) value = getattr(field, attr, None)
if value is not None and value != '': if value is not None and value != "":
field_info[attr] = force_text(value, strings_only=True) field_info[attr] = force_text(value, strings_only=True)
if getattr(field, 'child', None): if getattr(field, "child", None):
field_info['child'] = self.get_field_info(field.child) field_info["child"] = self.get_field_info(field.child)
elif getattr(field, 'fields', None): elif getattr(field, "fields", None):
field_info['children'] = self.get_serializer_info(field) field_info["children"] = self.get_serializer_info(field)
if (not field_info.get('read_only') and if (
not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and not field_info.get("read_only")
hasattr(field, 'choices')): and not isinstance(
field_info['choices'] = [ field, (serializers.RelatedField, serializers.ManyRelatedField)
)
and hasattr(field, "choices")
):
field_info["choices"] = [
{ {
'value': choice_value, "value": choice_value,
'display_name': force_text(choice_name, strings_only=True) "display_name": force_text(choice_name, strings_only=True),
} }
for choice_value, choice_name in field.choices.items() for choice_value, choice_name in field.choices.items()
] ]

View File

@ -15,19 +15,22 @@ class CreateModelMixin(object):
""" """
Create a model instance. Create a model instance.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data) serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_create(serializer) self.perform_create(serializer)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)
def perform_create(self, serializer): def perform_create(self, serializer):
serializer.save() serializer.save()
def get_success_headers(self, data): def get_success_headers(self, data):
try: try:
return {'Location': str(data[api_settings.URL_FIELD_NAME])} return {"Location": str(data[api_settings.URL_FIELD_NAME])}
except (TypeError, KeyError): except (TypeError, KeyError):
return {} return {}
@ -36,6 +39,7 @@ class ListModelMixin(object):
""" """
List a queryset. List a queryset.
""" """
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset()) queryset = self.filter_queryset(self.get_queryset())
@ -52,6 +56,7 @@ class RetrieveModelMixin(object):
""" """
Retrieve a model instance. Retrieve a model instance.
""" """
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
instance = self.get_object() instance = self.get_object()
serializer = self.get_serializer(instance) serializer = self.get_serializer(instance)
@ -62,14 +67,15 @@ class UpdateModelMixin(object):
""" """
Update a model instance. Update a model instance.
""" """
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop("partial", False)
instance = self.get_object() instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_update(serializer) self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None): if getattr(instance, "_prefetched_objects_cache", None):
# If 'prefetch_related' has been applied to a queryset, we need to # If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance. # forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {} instance._prefetched_objects_cache = {}
@ -80,7 +86,7 @@ class UpdateModelMixin(object):
serializer.save() serializer.save()
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True kwargs["partial"] = True
return self.update(request, *args, **kwargs) return self.update(request, *args, **kwargs)
@ -88,6 +94,7 @@ class DestroyModelMixin(object):
""" """
Destroy a model instance. Destroy a model instance.
""" """
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
instance = self.get_object() instance = self.get_object()
self.perform_destroy(instance) self.perform_destroy(instance)

View File

@ -9,16 +9,18 @@ from django.http import Http404
from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import ( from rest_framework.utils.mediatypes import (
_MediaType, media_type_matches, order_by_precedence _MediaType,
media_type_matches,
order_by_precedence,
) )
class BaseContentNegotiation(object): class BaseContentNegotiation(object):
def select_parser(self, request, parsers): def select_parser(self, request, parsers):
raise NotImplementedError('.select_parser() must be implemented') raise NotImplementedError(".select_parser() must be implemented")
def select_renderer(self, request, renderers, format_suffix=None): def select_renderer(self, request, renderers, format_suffix=None):
raise NotImplementedError('.select_renderer() must be implemented') raise NotImplementedError(".select_renderer() must be implemented")
class DefaultContentNegotiation(BaseContentNegotiation): class DefaultContentNegotiation(BaseContentNegotiation):
@ -59,16 +61,20 @@ class DefaultContentNegotiation(BaseContentNegotiation):
# Return the most specific media type as accepted. # Return the most specific media type as accepted.
media_type_wrapper = _MediaType(media_type) media_type_wrapper = _MediaType(media_type)
if ( if (
_MediaType(renderer.media_type).precedence > _MediaType(renderer.media_type).precedence
media_type_wrapper.precedence > media_type_wrapper.precedence
): ):
# Eg client requests '*/*' # Eg client requests '*/*'
# Accepted media type is 'application/json' # Accepted media type is 'application/json'
full_media_type = ';'.join( full_media_type = ";".join(
(renderer.media_type,) + (renderer.media_type,)
tuple('{0}={1}'.format( + tuple(
key, value.decode(HTTP_HEADER_ENCODING)) "{0}={1}".format(
for key, value in media_type_wrapper.params.items())) key, value.decode(HTTP_HEADER_ENCODING)
)
for key, value in media_type_wrapper.params.items()
)
)
return renderer, full_media_type return renderer, full_media_type
else: else:
# Eg client requests 'application/json; indent=8' # Eg client requests 'application/json; indent=8'
@ -82,8 +88,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
If there is a '.json' style format suffix, filter the renderers If there is a '.json' style format suffix, filter the renderers
so that we only negotiation against those that accept that format. so that we only negotiation against those that accept that format.
""" """
renderers = [renderer for renderer in renderers renderers = [renderer for renderer in renderers if renderer.format == format]
if renderer.format == format]
if not renderers: if not renderers:
raise Http404 raise Http404
return renderers return renderers
@ -93,5 +98,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Given the incoming request, return a tokenized list of media Given the incoming request, return a tokenized list of media
type strings. type strings.
""" """
header = request.META.get('HTTP_ACCEPT', '*/*') header = request.META.get("HTTP_ACCEPT", "*/*")
return [token.strip() for token in header.split(',')] return [token.strip() for token in header.split(",")]

View File

@ -8,8 +8,7 @@ from __future__ import unicode_literals
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from django.core.paginator import InvalidPage from django.core.paginator import InvalidPage, Paginator as DjangoPaginator
from django.core.paginator import Paginator as DjangoPaginator
from django.template import loader from django.template import loader
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
@ -83,10 +82,7 @@ def _get_displayed_page_numbers(current, final):
included.add(final - 2) included.add(final - 2)
# Now sort the page numbers and drop anything outside the limits. # Now sort the page numbers and drop anything outside the limits.
included = [ included = [idx for idx in sorted(list(included)) if 0 < idx <= final]
idx for idx in sorted(list(included))
if 0 < idx <= final
]
# Finally insert any `...` breaks # Finally insert any `...` breaks
if current > 4: if current > 4:
@ -110,7 +106,7 @@ def _get_page_links(page_numbers, current, url_func):
url=url_func(page_number), url=url_func(page_number),
number=page_number, number=page_number,
is_active=(page_number == current), is_active=(page_number == current),
is_break=False is_break=False,
) )
page_links.append(page_link) page_links.append(page_link)
return page_links return page_links
@ -121,14 +117,15 @@ def _reverse_ordering(ordering_tuple):
Given an order_by tuple such as `('-created', 'uuid')` reverse the Given an order_by tuple such as `('-created', 'uuid')` reverse the
ordering and return a new tuple, eg. `('created', '-uuid')`. ordering and return a new tuple, eg. `('created', '-uuid')`.
""" """
def invert(x): def invert(x):
return x[1:] if x.startswith('-') else '-' + x return x[1:] if x.startswith("-") else "-" + x
return tuple([invert(item) for item in ordering_tuple]) return tuple([invert(item) for item in ordering_tuple])
Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position']) Cursor = namedtuple("Cursor", ["offset", "reverse", "position"])
PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break']) PageLink = namedtuple("PageLink", ["url", "number", "is_active", "is_break"])
PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True) PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True)
@ -137,19 +134,23 @@ class BasePagination(object):
display_page_controls = False display_page_controls = False
def paginate_queryset(self, queryset, request, view=None): # pragma: no cover def paginate_queryset(self, queryset, request, view=None): # pragma: no cover
raise NotImplementedError('paginate_queryset() must be implemented.') raise NotImplementedError("paginate_queryset() must be implemented.")
def get_paginated_response(self, data): # pragma: no cover def get_paginated_response(self, data): # pragma: no cover
raise NotImplementedError('get_paginated_response() must be implemented.') raise NotImplementedError("get_paginated_response() must be implemented.")
def to_html(self): # pragma: no cover def to_html(self): # pragma: no cover
raise NotImplementedError('to_html() must be implemented to display page controls.') raise NotImplementedError(
"to_html() must be implemented to display page controls."
)
def get_results(self, data): def get_results(self, data):
return data['results'] return data["results"]
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
return [] return []
@ -161,6 +162,7 @@ class PageNumberPagination(BasePagination):
http://api.example.org/accounts/?page=4 http://api.example.org/accounts/?page=4
http://api.example.org/accounts/?page=4&page_size=100 http://api.example.org/accounts/?page=4&page_size=100
""" """
# The default page size. # The default page size.
# Defaults to `None`, meaning pagination is disabled. # Defaults to `None`, meaning pagination is disabled.
page_size = api_settings.PAGE_SIZE page_size = api_settings.PAGE_SIZE
@ -168,23 +170,23 @@ class PageNumberPagination(BasePagination):
django_paginator_class = DjangoPaginator django_paginator_class = DjangoPaginator
# Client can control the page using this query parameter. # Client can control the page using this query parameter.
page_query_param = 'page' page_query_param = "page"
page_query_description = _('A page number within the paginated result set.') page_query_description = _("A page number within the paginated result set.")
# Client can control the page size using this query parameter. # Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage. # Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None page_size_query_param = None
page_size_query_description = _('Number of results to return per page.') page_size_query_description = _("Number of results to return per page.")
# Set to an integer to limit the maximum page size the client may request. # Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set. # Only relevant if 'page_size_query_param' has also been set.
max_page_size = None max_page_size = None
last_page_strings = ('last',) last_page_strings = ("last",)
template = 'rest_framework/pagination/numbers.html' template = "rest_framework/pagination/numbers.html"
invalid_page_message = _('Invalid page.') invalid_page_message = _("Invalid page.")
def paginate_queryset(self, queryset, request, view=None): def paginate_queryset(self, queryset, request, view=None):
""" """
@ -216,12 +218,16 @@ class PageNumberPagination(BasePagination):
return list(self.page) return list(self.page)
def get_paginated_response(self, data): def get_paginated_response(self, data):
return Response(OrderedDict([ return Response(
('count', self.page.paginator.count), OrderedDict(
('next', self.get_next_link()), [
('previous', self.get_previous_link()), ("count", self.page.paginator.count),
('results', data) ("next", self.get_next_link()),
])) ("previous", self.get_previous_link()),
("results", data),
]
)
)
def get_page_size(self, request): def get_page_size(self, request):
if self.page_size_query_param: if self.page_size_query_param:
@ -229,7 +235,7 @@ class PageNumberPagination(BasePagination):
return _positive_int( return _positive_int(
request.query_params[self.page_size_query_param], request.query_params[self.page_size_query_param],
strict=True, strict=True,
cutoff=self.max_page_size cutoff=self.max_page_size,
) )
except (KeyError, ValueError): except (KeyError, ValueError):
pass pass
@ -267,9 +273,9 @@ class PageNumberPagination(BasePagination):
page_links = _get_page_links(page_numbers, current, page_number_to_url) page_links = _get_page_links(page_numbers, current, page_number_to_url)
return { return {
'previous_url': self.get_previous_link(), "previous_url": self.get_previous_link(),
'next_url': self.get_next_link(), "next_url": self.get_next_link(),
'page_links': page_links "page_links": page_links,
} }
def to_html(self): def to_html(self):
@ -278,17 +284,20 @@ class PageNumberPagination(BasePagination):
return template.render(context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert (
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
fields = [ fields = [
coreapi.Field( coreapi.Field(
name=self.page_query_param, name=self.page_query_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.Integer( schema=coreschema.Integer(
title='Page', title="Page", description=force_text(self.page_query_description)
description=force_text(self.page_query_description) ),
)
) )
] ]
if self.page_size_query_param is not None: if self.page_size_query_param is not None:
@ -296,11 +305,11 @@ class PageNumberPagination(BasePagination):
coreapi.Field( coreapi.Field(
name=self.page_size_query_param, name=self.page_size_query_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.Integer( schema=coreschema.Integer(
title='Page size', title="Page size",
description=force_text(self.page_size_query_description) description=force_text(self.page_size_query_description),
) ),
) )
) )
return fields return fields
@ -313,13 +322,14 @@ class LimitOffsetPagination(BasePagination):
http://api.example.org/accounts/?limit=100 http://api.example.org/accounts/?limit=100
http://api.example.org/accounts/?offset=400&limit=100 http://api.example.org/accounts/?offset=400&limit=100
""" """
default_limit = api_settings.PAGE_SIZE default_limit = api_settings.PAGE_SIZE
limit_query_param = 'limit' limit_query_param = "limit"
limit_query_description = _('Number of results to return per page.') limit_query_description = _("Number of results to return per page.")
offset_query_param = 'offset' offset_query_param = "offset"
offset_query_description = _('The initial index from which to return the results.') offset_query_description = _("The initial index from which to return the results.")
max_limit = None max_limit = None
template = 'rest_framework/pagination/numbers.html' template = "rest_framework/pagination/numbers.html"
def paginate_queryset(self, queryset, request, view=None): def paginate_queryset(self, queryset, request, view=None):
self.count = self.get_count(queryset) self.count = self.get_count(queryset)
@ -334,15 +344,19 @@ class LimitOffsetPagination(BasePagination):
if self.count == 0 or self.offset > self.count: if self.count == 0 or self.offset > self.count:
return [] return []
return list(queryset[self.offset:self.offset + self.limit]) return list(queryset[self.offset : self.offset + self.limit])
def get_paginated_response(self, data): def get_paginated_response(self, data):
return Response(OrderedDict([ return Response(
('count', self.count), OrderedDict(
('next', self.get_next_link()), [
('previous', self.get_previous_link()), ("count", self.count),
('results', data) ("next", self.get_next_link()),
])) ("previous", self.get_previous_link()),
("results", data),
]
)
)
def get_limit(self, request): def get_limit(self, request):
if self.limit_query_param: if self.limit_query_param:
@ -350,7 +364,7 @@ class LimitOffsetPagination(BasePagination):
return _positive_int( return _positive_int(
request.query_params[self.limit_query_param], request.query_params[self.limit_query_param],
strict=True, strict=True,
cutoff=self.max_limit cutoff=self.max_limit,
) )
except (KeyError, ValueError): except (KeyError, ValueError):
pass pass
@ -359,9 +373,7 @@ class LimitOffsetPagination(BasePagination):
def get_offset(self, request): def get_offset(self, request):
try: try:
return _positive_int( return _positive_int(request.query_params[self.offset_query_param])
request.query_params[self.offset_query_param],
)
except (KeyError, ValueError): except (KeyError, ValueError):
return 0 return 0
@ -399,10 +411,9 @@ class LimitOffsetPagination(BasePagination):
# plus the number of pages up to the current offset. # plus the number of pages up to the current offset.
# When offset is not strictly divisible by the limit then we may # When offset is not strictly divisible by the limit then we may
# end up introducing an extra page as an artifact. # end up introducing an extra page as an artifact.
final = ( final = _divide_with_ceil(
_divide_with_ceil(self.count - self.offset, self.limit) + self.count - self.offset, self.limit
_divide_with_ceil(self.offset, self.limit) ) + _divide_with_ceil(self.offset, self.limit)
)
if final < 1: if final < 1:
final = 1 final = 1
@ -424,9 +435,9 @@ class LimitOffsetPagination(BasePagination):
page_links = _get_page_links(page_numbers, current, page_number_to_url) page_links = _get_page_links(page_numbers, current, page_number_to_url)
return { return {
'previous_url': self.get_previous_link(), "previous_url": self.get_previous_link(),
'next_url': self.get_next_link(), "next_url": self.get_next_link(),
'page_links': page_links "page_links": page_links,
} }
def to_html(self): def to_html(self):
@ -435,27 +446,30 @@ class LimitOffsetPagination(BasePagination):
return template.render(context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert (
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return [ return [
coreapi.Field( coreapi.Field(
name=self.limit_query_param, name=self.limit_query_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.Integer( schema=coreschema.Integer(
title='Limit', title="Limit", description=force_text(self.limit_query_description)
description=force_text(self.limit_query_description) ),
)
), ),
coreapi.Field( coreapi.Field(
name=self.offset_query_param, name=self.offset_query_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.Integer( schema=coreschema.Integer(
title='Offset', title="Offset",
description=force_text(self.offset_query_description) description=force_text(self.offset_query_description),
) ),
) ),
] ]
def get_count(self, queryset): def get_count(self, queryset):
@ -474,17 +488,18 @@ class CursorPagination(BasePagination):
For an overview of the position/offset style we use, see this post: For an overview of the position/offset style we use, see this post:
https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api
""" """
cursor_query_param = 'cursor'
cursor_query_description = _('The pagination cursor value.') cursor_query_param = "cursor"
cursor_query_description = _("The pagination cursor value.")
page_size = api_settings.PAGE_SIZE page_size = api_settings.PAGE_SIZE
invalid_cursor_message = _('Invalid cursor') invalid_cursor_message = _("Invalid cursor")
ordering = '-created' ordering = "-created"
template = 'rest_framework/pagination/previous_and_next.html' template = "rest_framework/pagination/previous_and_next.html"
# Client can control the page size using this query parameter. # Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage. # Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None page_size_query_param = None
page_size_query_description = _('Number of results to return per page.') page_size_query_description = _("Number of results to return per page.")
# Set to an integer to limit the maximum page size the client may request. # Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set. # Only relevant if 'page_size_query_param' has also been set.
@ -519,27 +534,29 @@ class CursorPagination(BasePagination):
# If we have a cursor with a fixed position then filter by that. # If we have a cursor with a fixed position then filter by that.
if current_position is not None: if current_position is not None:
order = self.ordering[0] order = self.ordering[0]
is_reversed = order.startswith('-') is_reversed = order.startswith("-")
order_attr = order.lstrip('-') order_attr = order.lstrip("-")
# Test for: (cursor reversed) XOR (queryset reversed) # Test for: (cursor reversed) XOR (queryset reversed)
if self.cursor.reverse != is_reversed: if self.cursor.reverse != is_reversed:
kwargs = {order_attr + '__lt': current_position} kwargs = {order_attr + "__lt": current_position}
else: else:
kwargs = {order_attr + '__gt': current_position} kwargs = {order_attr + "__gt": current_position}
queryset = queryset.filter(**kwargs) queryset = queryset.filter(**kwargs)
# If we have an offset cursor then offset the entire page by that amount. # If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a # We also always fetch an extra item in order to determine if there is a
# page following on from this one. # page following on from this one.
results = list(queryset[offset:offset + self.page_size + 1]) results = list(queryset[offset : offset + self.page_size + 1])
self.page = list(results[:self.page_size]) self.page = list(results[: self.page_size])
# Determine the position of the final item following the page. # Determine the position of the final item following the page.
if len(results) > len(self.page): if len(results) > len(self.page):
has_following_position = True has_following_position = True
following_position = self._get_position_from_instance(results[-1], self.ordering) following_position = self._get_position_from_instance(
results[-1], self.ordering
)
else: else:
has_following_position = False has_following_position = False
following_position = None following_position = None
@ -578,7 +595,7 @@ class CursorPagination(BasePagination):
return _positive_int( return _positive_int(
request.query_params[self.page_size_query_param], request.query_params[self.page_size_query_param],
strict=True, strict=True,
cutoff=self.max_page_size cutoff=self.max_page_size,
) )
except (KeyError, ValueError): except (KeyError, ValueError):
pass pass
@ -686,8 +703,9 @@ class CursorPagination(BasePagination):
Return a tuple of strings, that may be used in an `order_by` method. Return a tuple of strings, that may be used in an `order_by` method.
""" """
ordering_filters = [ ordering_filters = [
filter_cls for filter_cls in getattr(view, 'filter_backends', []) filter_cls
if hasattr(filter_cls, 'get_ordering') for filter_cls in getattr(view, "filter_backends", [])
if hasattr(filter_cls, "get_ordering")
] ]
if ordering_filters: if ordering_filters:
@ -697,29 +715,27 @@ class CursorPagination(BasePagination):
filter_instance = filter_cls() filter_instance = filter_cls()
ordering = filter_instance.get_ordering(request, queryset, view) ordering = filter_instance.get_ordering(request, queryset, view)
assert ordering is not None, ( assert ordering is not None, (
'Using cursor pagination, but filter class {filter_cls} ' "Using cursor pagination, but filter class {filter_cls} "
'returned a `None` ordering.'.format( "returned a `None` ordering.".format(filter_cls=filter_cls.__name__)
filter_cls=filter_cls.__name__
)
) )
else: else:
# The default case is to check for an `ordering` attribute # The default case is to check for an `ordering` attribute
# on this pagination instance. # on this pagination instance.
ordering = self.ordering ordering = self.ordering
assert ordering is not None, ( assert ordering is not None, (
'Using cursor pagination, but no ordering attribute was declared ' "Using cursor pagination, but no ordering attribute was declared "
'on the pagination class.' "on the pagination class."
) )
assert '__' not in ordering, ( assert "__" not in ordering, (
'Cursor pagination does not support double underscore lookups ' "Cursor pagination does not support double underscore lookups "
'for orderings. Orderings should be an unchanging, unique or ' "for orderings. Orderings should be an unchanging, unique or "
'nearly-unique field on the model, such as "-created" or "pk".' 'nearly-unique field on the model, such as "-created" or "pk".'
) )
assert isinstance(ordering, (six.string_types, list, tuple)), ( assert isinstance(
'Invalid ordering. Expected string or tuple, but got {type}'.format( ordering, (six.string_types, list, tuple)
type=type(ordering).__name__ ), "Invalid ordering. Expected string or tuple, but got {type}".format(
) type=type(ordering).__name__
) )
if isinstance(ordering, six.string_types): if isinstance(ordering, six.string_types):
@ -736,16 +752,16 @@ class CursorPagination(BasePagination):
return None return None
try: try:
querystring = b64decode(encoded.encode('ascii')).decode('ascii') querystring = b64decode(encoded.encode("ascii")).decode("ascii")
tokens = urlparse.parse_qs(querystring, keep_blank_values=True) tokens = urlparse.parse_qs(querystring, keep_blank_values=True)
offset = tokens.get('o', ['0'])[0] offset = tokens.get("o", ["0"])[0]
offset = _positive_int(offset, cutoff=self.offset_cutoff) offset = _positive_int(offset, cutoff=self.offset_cutoff)
reverse = tokens.get('r', ['0'])[0] reverse = tokens.get("r", ["0"])[0]
reverse = bool(int(reverse)) reverse = bool(int(reverse))
position = tokens.get('p', [None])[0] position = tokens.get("p", [None])[0]
except (TypeError, ValueError): except (TypeError, ValueError):
raise NotFound(self.invalid_cursor_message) raise NotFound(self.invalid_cursor_message)
@ -757,18 +773,18 @@ class CursorPagination(BasePagination):
""" """
tokens = {} tokens = {}
if cursor.offset != 0: if cursor.offset != 0:
tokens['o'] = str(cursor.offset) tokens["o"] = str(cursor.offset)
if cursor.reverse: if cursor.reverse:
tokens['r'] = '1' tokens["r"] = "1"
if cursor.position is not None: if cursor.position is not None:
tokens['p'] = cursor.position tokens["p"] = cursor.position
querystring = urlparse.urlencode(tokens, doseq=True) querystring = urlparse.urlencode(tokens, doseq=True)
encoded = b64encode(querystring.encode('ascii')).decode('ascii') encoded = b64encode(querystring.encode("ascii")).decode("ascii")
return replace_query_param(self.base_url, self.cursor_query_param, encoded) return replace_query_param(self.base_url, self.cursor_query_param, encoded)
def _get_position_from_instance(self, instance, ordering): def _get_position_from_instance(self, instance, ordering):
field_name = ordering[0].lstrip('-') field_name = ordering[0].lstrip("-")
if isinstance(instance, dict): if isinstance(instance, dict):
attr = instance[field_name] attr = instance[field_name]
else: else:
@ -776,16 +792,20 @@ class CursorPagination(BasePagination):
return six.text_type(attr) return six.text_type(attr)
def get_paginated_response(self, data): def get_paginated_response(self, data):
return Response(OrderedDict([ return Response(
('next', self.get_next_link()), OrderedDict(
('previous', self.get_previous_link()), [
('results', data) ("next", self.get_next_link()),
])) ("previous", self.get_previous_link()),
("results", data),
]
)
)
def get_html_context(self): def get_html_context(self):
return { return {
'previous_url': self.get_previous_link(), "previous_url": self.get_previous_link(),
'next_url': self.get_next_link() "next_url": self.get_next_link(),
} }
def to_html(self): def to_html(self):
@ -794,17 +814,21 @@ class CursorPagination(BasePagination):
return template.render(context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert (
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
fields = [ fields = [
coreapi.Field( coreapi.Field(
name=self.cursor_query_param, name=self.cursor_query_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.String( schema=coreschema.String(
title='Cursor', title="Cursor",
description=force_text(self.cursor_query_description) description=force_text(self.cursor_query_description),
) ),
) )
] ]
if self.page_size_query_param is not None: if self.page_size_query_param is not None:
@ -812,11 +836,11 @@ class CursorPagination(BasePagination):
coreapi.Field( coreapi.Field(
name=self.page_size_query_param, name=self.page_size_query_param,
required=False, required=False,
location='query', location="query",
schema=coreschema.Integer( schema=coreschema.Integer(
title='Page size', title="Page size",
description=force_text(self.page_size_query_description) description=force_text(self.page_size_query_description),
) ),
) )
) )
return fields return fields

View File

@ -11,10 +11,12 @@ import codecs
from django.conf import settings from django.conf import settings
from django.core.files.uploadhandler import StopFutureHandlers from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict from django.http import QueryDict
from django.http.multipartparser import ChunkIter from django.http.multipartparser import (
from django.http.multipartparser import \ ChunkIter,
MultiPartParser as DjangoMultiPartParser MultiPartParser as DjangoMultiPartParser,
from django.http.multipartparser import MultiPartParserError, parse_header MultiPartParserError,
parse_header,
)
from django.utils import six from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.six.moves.urllib import parse as urlparse from django.utils.six.moves.urllib import parse as urlparse
@ -36,6 +38,7 @@ class BaseParser(object):
All parsers should extend `BaseParser`, specifying a `media_type` All parsers should extend `BaseParser`, specifying a `media_type`
attribute, and overriding the `.parse()` method. attribute, and overriding the `.parse()` method.
""" """
media_type = None media_type = None
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
@ -51,7 +54,8 @@ class JSONParser(BaseParser):
""" """
Parses JSON-serialized data. Parses JSON-serialized data.
""" """
media_type = 'application/json'
media_type = "application/json"
renderer_class = renderers.JSONRenderer renderer_class = renderers.JSONRenderer
strict = api_settings.STRICT_JSON strict = api_settings.STRICT_JSON
@ -60,21 +64,22 @@ class JSONParser(BaseParser):
Parses the incoming bytestream as JSON and returns the resulting data. Parses the incoming bytestream as JSON and returns the resulting data.
""" """
parser_context = parser_context or {} parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
try: try:
decoded_stream = codecs.getreader(encoding)(stream) decoded_stream = codecs.getreader(encoding)(stream)
parse_constant = json.strict_constant if self.strict else None parse_constant = json.strict_constant if self.strict else None
return json.load(decoded_stream, parse_constant=parse_constant) return json.load(decoded_stream, parse_constant=parse_constant)
except ValueError as exc: except ValueError as exc:
raise ParseError('JSON parse error - %s' % six.text_type(exc)) raise ParseError("JSON parse error - %s" % six.text_type(exc))
class FormParser(BaseParser): class FormParser(BaseParser):
""" """
Parser for form data. Parser for form data.
""" """
media_type = 'application/x-www-form-urlencoded'
media_type = "application/x-www-form-urlencoded"
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
""" """
@ -82,7 +87,7 @@ class FormParser(BaseParser):
and returns the resulting QueryDict. and returns the resulting QueryDict.
""" """
parser_context = parser_context or {} parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
data = QueryDict(stream.read(), encoding=encoding) data = QueryDict(stream.read(), encoding=encoding)
return data return data
@ -91,7 +96,8 @@ class MultiPartParser(BaseParser):
""" """
Parser for multipart form data, which may include file data. Parser for multipart form data, which may include file data.
""" """
media_type = 'multipart/form-data'
media_type = "multipart/form-data"
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
""" """
@ -102,10 +108,10 @@ class MultiPartParser(BaseParser):
`.files` will be a `QueryDict` containing all the form files. `.files` will be a `QueryDict` containing all the form files.
""" """
parser_context = parser_context or {} parser_context = parser_context or {}
request = parser_context['request'] request = parser_context["request"]
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
meta = request.META.copy() meta = request.META.copy()
meta['CONTENT_TYPE'] = media_type meta["CONTENT_TYPE"] = media_type
upload_handlers = request.upload_handlers upload_handlers = request.upload_handlers
try: try:
@ -113,17 +119,18 @@ class MultiPartParser(BaseParser):
data, files = parser.parse() data, files = parser.parse()
return DataAndFiles(data, files) return DataAndFiles(data, files)
except MultiPartParserError as exc: except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % six.text_type(exc)) raise ParseError("Multipart form parse error - %s" % six.text_type(exc))
class FileUploadParser(BaseParser): class FileUploadParser(BaseParser):
""" """
Parser for file upload data. Parser for file upload data.
""" """
media_type = '*/*'
media_type = "*/*"
errors = { errors = {
'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream', "unhandled": "FileUpload parse error - none of upload handlers can handle the stream",
'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.', "no_filename": "Missing filename. Request should include a Content-Disposition header with a filename parameter.",
} }
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
@ -135,34 +142,32 @@ class FileUploadParser(BaseParser):
`.files` will be a `QueryDict` containing one 'file' element. `.files` will be a `QueryDict` containing one 'file' element.
""" """
parser_context = parser_context or {} parser_context = parser_context or {}
request = parser_context['request'] request = parser_context["request"]
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
meta = request.META meta = request.META
upload_handlers = request.upload_handlers upload_handlers = request.upload_handlers
filename = self.get_filename(stream, media_type, parser_context) filename = self.get_filename(stream, media_type, parser_context)
if not filename: if not filename:
raise ParseError(self.errors['no_filename']) raise ParseError(self.errors["no_filename"])
# Note that this code is extracted from Django's handling of # Note that this code is extracted from Django's handling of
# file uploads in MultiPartParser. # file uploads in MultiPartParser.
content_type = meta.get('HTTP_CONTENT_TYPE', content_type = meta.get("HTTP_CONTENT_TYPE", meta.get("CONTENT_TYPE", ""))
meta.get('CONTENT_TYPE', ''))
try: try:
content_length = int(meta.get('HTTP_CONTENT_LENGTH', content_length = int(
meta.get('CONTENT_LENGTH', 0))) meta.get("HTTP_CONTENT_LENGTH", meta.get("CONTENT_LENGTH", 0))
)
except (ValueError, TypeError): except (ValueError, TypeError):
content_length = None content_length = None
# See if the handler will want to take care of the parsing. # See if the handler will want to take care of the parsing.
for handler in upload_handlers: for handler in upload_handlers:
result = handler.handle_raw_input(stream, result = handler.handle_raw_input(
meta, stream, meta, content_length, None, encoding
content_length, )
None,
encoding)
if result is not None: if result is not None:
return DataAndFiles({}, {'file': result[1]}) return DataAndFiles({}, {"file": result[1]})
# This is the standard case. # This is the standard case.
possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
@ -172,10 +177,9 @@ class FileUploadParser(BaseParser):
for index, handler in enumerate(upload_handlers): for index, handler in enumerate(upload_handlers):
try: try:
handler.new_file(None, filename, content_type, handler.new_file(None, filename, content_type, content_length, encoding)
content_length, encoding)
except StopFutureHandlers: except StopFutureHandlers:
upload_handlers = upload_handlers[:index + 1] upload_handlers = upload_handlers[: index + 1]
break break
for chunk in chunks: for chunk in chunks:
@ -189,9 +193,9 @@ class FileUploadParser(BaseParser):
for index, handler in enumerate(upload_handlers): for index, handler in enumerate(upload_handlers):
file_obj = handler.file_complete(counters[index]) file_obj = handler.file_complete(counters[index])
if file_obj is not None: if file_obj is not None:
return DataAndFiles({}, {'file': file_obj}) return DataAndFiles({}, {"file": file_obj})
raise ParseError(self.errors['unhandled']) raise ParseError(self.errors["unhandled"])
def get_filename(self, stream, media_type, parser_context): def get_filename(self, stream, media_type, parser_context):
""" """
@ -199,17 +203,17 @@ class FileUploadParser(BaseParser):
Then tries to parse Content-Disposition header. Then tries to parse Content-Disposition header.
""" """
try: try:
return parser_context['kwargs']['filename'] return parser_context["kwargs"]["filename"]
except KeyError: except KeyError:
pass pass
try: try:
meta = parser_context['request'].META meta = parser_context["request"].META
disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8')) disposition = parse_header(meta["HTTP_CONTENT_DISPOSITION"].encode("utf-8"))
filename_parm = disposition[1] filename_parm = disposition[1]
if 'filename*' in filename_parm: if "filename*" in filename_parm:
return self.get_encoded_filename(filename_parm) return self.get_encoded_filename(filename_parm)
return force_text(filename_parm['filename']) return force_text(filename_parm["filename"])
except (AttributeError, KeyError, ValueError): except (AttributeError, KeyError, ValueError):
pass pass
@ -218,10 +222,10 @@ class FileUploadParser(BaseParser):
Handle encoded filenames per RFC6266. See also: Handle encoded filenames per RFC6266. See also:
https://tools.ietf.org/html/rfc2231#section-4 https://tools.ietf.org/html/rfc2231#section-4
""" """
encoded_filename = force_text(filename_parm['filename*']) encoded_filename = force_text(filename_parm["filename*"])
try: try:
charset, lang, filename = encoded_filename.split('\'', 2) charset, lang, filename = encoded_filename.split("'", 2)
filename = urlparse.unquote(filename) filename = urlparse.unquote(filename)
except (ValueError, LookupError): except (ValueError, LookupError):
filename = force_text(filename_parm['filename']) filename = force_text(filename_parm["filename"])
return filename return filename

View File

@ -8,7 +8,8 @@ from django.utils import six
from rest_framework import exceptions from rest_framework import exceptions
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
SAFE_METHODS = ("GET", "HEAD", "OPTIONS")
class OperationHolderMixin: class OperationHolderMixin:
@ -56,16 +57,14 @@ class AND:
self.op2 = op2 self.op2 = op2
def has_permission(self, request, view): def has_permission(self, request, view):
return ( return self.op1.has_permission(request, view) and self.op2.has_permission(
self.op1.has_permission(request, view) and request, view
self.op2.has_permission(request, view)
) )
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
return ( return self.op1.has_object_permission(
self.op1.has_object_permission(request, view, obj) and request, view, obj
self.op2.has_object_permission(request, view, obj) ) and self.op2.has_object_permission(request, view, obj)
)
class OR: class OR:
@ -74,16 +73,14 @@ class OR:
self.op2 = op2 self.op2 = op2
def has_permission(self, request, view): def has_permission(self, request, view):
return ( return self.op1.has_permission(request, view) or self.op2.has_permission(
self.op1.has_permission(request, view) or request, view
self.op2.has_permission(request, view)
) )
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
return ( return self.op1.has_object_permission(
self.op1.has_object_permission(request, view, obj) or request, view, obj
self.op2.has_object_permission(request, view, obj) ) or self.op2.has_object_permission(request, view, obj)
)
class NOT: class NOT:
@ -157,9 +154,9 @@ class IsAuthenticatedOrReadOnly(BasePermission):
def has_permission(self, request, view): def has_permission(self, request, view):
return bool( return bool(
request.method in SAFE_METHODS or request.method in SAFE_METHODS
request.user and or request.user
request.user.is_authenticated and request.user.is_authenticated
) )
@ -179,13 +176,13 @@ class DjangoModelPermissions(BasePermission):
# Override this if you need to also provide 'view' permissions, # Override this if you need to also provide 'view' permissions,
# or if you want to provide custom permission codes. # or if you want to provide custom permission codes.
perms_map = { perms_map = {
'GET': [], "GET": [],
'OPTIONS': [], "OPTIONS": [],
'HEAD': [], "HEAD": [],
'POST': ['%(app_label)s.add_%(model_name)s'], "POST": ["%(app_label)s.add_%(model_name)s"],
'PUT': ['%(app_label)s.change_%(model_name)s'], "PUT": ["%(app_label)s.change_%(model_name)s"],
'PATCH': ['%(app_label)s.change_%(model_name)s'], "PATCH": ["%(app_label)s.change_%(model_name)s"],
'DELETE': ['%(app_label)s.delete_%(model_name)s'], "DELETE": ["%(app_label)s.delete_%(model_name)s"],
} }
authenticated_users_only = True authenticated_users_only = True
@ -196,8 +193,8 @@ class DjangoModelPermissions(BasePermission):
codes that the user is required to have. codes that the user is required to have.
""" """
kwargs = { kwargs = {
'app_label': model_cls._meta.app_label, "app_label": model_cls._meta.app_label,
'model_name': model_cls._meta.model_name "model_name": model_cls._meta.model_name,
} }
if method not in self.perms_map: if method not in self.perms_map:
@ -206,16 +203,19 @@ class DjangoModelPermissions(BasePermission):
return [perm % kwargs for perm in self.perms_map[method]] return [perm % kwargs for perm in self.perms_map[method]]
def _queryset(self, view): def _queryset(self, view):
assert hasattr(view, 'get_queryset') \ assert (
or getattr(view, 'queryset', None) is not None, ( hasattr(view, "get_queryset") or getattr(view, "queryset", None) is not None
'Cannot apply {} on a view that does not set ' ), (
'`.queryset` or have a `.get_queryset()` method.' "Cannot apply {} on a view that does not set "
).format(self.__class__.__name__) "`.queryset` or have a `.get_queryset()` method."
).format(
self.__class__.__name__
)
if hasattr(view, 'get_queryset'): if hasattr(view, "get_queryset"):
queryset = view.get_queryset() queryset = view.get_queryset()
assert queryset is not None, ( assert queryset is not None, "{}.get_queryset() returned None".format(
'{}.get_queryset() returned None'.format(view.__class__.__name__) view.__class__.__name__
) )
return queryset return queryset
return view.queryset return view.queryset
@ -223,11 +223,12 @@ class DjangoModelPermissions(BasePermission):
def has_permission(self, request, view): def has_permission(self, request, view):
# Workaround to ensure DjangoModelPermissions are not applied # Workaround to ensure DjangoModelPermissions are not applied
# to the root view when using DefaultRouter. # to the root view when using DefaultRouter.
if getattr(view, '_ignore_model_permissions', False): if getattr(view, "_ignore_model_permissions", False):
return True return True
if not request.user or ( if not request.user or (
not request.user.is_authenticated and self.authenticated_users_only): not request.user.is_authenticated and self.authenticated_users_only
):
return False return False
queryset = self._queryset(view) queryset = self._queryset(view)
@ -241,6 +242,7 @@ class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
Similar to DjangoModelPermissions, except that anonymous users are Similar to DjangoModelPermissions, except that anonymous users are
allowed read-only access. allowed read-only access.
""" """
authenticated_users_only = False authenticated_users_only = False
@ -255,20 +257,21 @@ class DjangoObjectPermissions(DjangoModelPermissions):
This permission can only be applied against view classes that This permission can only be applied against view classes that
provide a `.queryset` attribute. provide a `.queryset` attribute.
""" """
perms_map = { perms_map = {
'GET': [], "GET": [],
'OPTIONS': [], "OPTIONS": [],
'HEAD': [], "HEAD": [],
'POST': ['%(app_label)s.add_%(model_name)s'], "POST": ["%(app_label)s.add_%(model_name)s"],
'PUT': ['%(app_label)s.change_%(model_name)s'], "PUT": ["%(app_label)s.change_%(model_name)s"],
'PATCH': ['%(app_label)s.change_%(model_name)s'], "PATCH": ["%(app_label)s.change_%(model_name)s"],
'DELETE': ['%(app_label)s.delete_%(model_name)s'], "DELETE": ["%(app_label)s.delete_%(model_name)s"],
} }
def get_required_object_permissions(self, method, model_cls): def get_required_object_permissions(self, method, model_cls):
kwargs = { kwargs = {
'app_label': model_cls._meta.app_label, "app_label": model_cls._meta.app_label,
'model_name': model_cls._meta.model_name "model_name": model_cls._meta.model_name,
} }
if method not in self.perms_map: if method not in self.perms_map:
@ -294,7 +297,7 @@ class DjangoObjectPermissions(DjangoModelPermissions):
# to make another lookup. # to make another lookup.
raise Http404 raise Http404
read_perms = self.get_required_object_permissions('GET', model_cls) read_perms = self.get_required_object_permissions("GET", model_cls)
if not user.has_perms(read_perms, obj): if not user.has_perms(read_perms, obj):
raise Http404 raise Http404

View File

@ -9,14 +9,16 @@ from django.db.models import Manager
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils import six from django.utils import six
from django.utils.encoding import ( from django.utils.encoding import python_2_unicode_compatible, smart_text, uri_to_iri
python_2_unicode_compatible, smart_text, uri_to_iri
)
from django.utils.six.moves.urllib import parse as urlparse from django.utils.six.moves.urllib import parse as urlparse
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.fields import ( from rest_framework.fields import (
Field, empty, get_attribute, is_simple_callable, iter_options Field,
empty,
get_attribute,
is_simple_callable,
iter_options,
) )
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -28,7 +30,7 @@ def method_overridden(method_name, klass, instance):
Determine if a method has been overridden. Determine if a method has been overridden.
""" """
method = getattr(klass, method_name) method = getattr(klass, method_name)
default_method = getattr(method, '__func__', method) # Python 3 compat default_method = getattr(method, "__func__", method) # Python 3 compat
return default_method is not getattr(instance, method_name).__func__ return default_method is not getattr(instance, method_name).__func__
@ -52,13 +54,14 @@ class Hyperlink(six.text_type):
We use this for hyperlinked URLs that may render as a named link We use this for hyperlinked URLs that may render as a named link
in some contexts, or render as a plain URL in others. in some contexts, or render as a plain URL in others.
""" """
def __new__(self, url, obj): def __new__(self, url, obj):
ret = six.text_type.__new__(self, url) ret = six.text_type.__new__(self, url)
ret.obj = obj ret.obj = obj
return ret return ret
def __getnewargs__(self): def __getnewargs__(self):
return(str(self), self.name,) return (str(self), self.name)
@property @property
def name(self): def name(self):
@ -77,6 +80,7 @@ class PKOnlyObject(object):
instance, but still want to return an object with a .pk attribute, instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance. in order to keep the same interface as a regular model instance.
""" """
def __init__(self, pk): def __init__(self, pk):
self.pk = pk self.pk = pk
@ -87,9 +91,19 @@ class PKOnlyObject(object):
# We assume that 'validators' are intended for the child serializer, # We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer. # rather than the parent serializer.
MANY_RELATION_KWARGS = ( MANY_RELATION_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source', "read_only",
'label', 'help_text', 'style', 'error_messages', 'allow_empty', "write_only",
'html_cutoff', 'html_cutoff_text' "required",
"default",
"initial",
"source",
"label",
"help_text",
"style",
"error_messages",
"allow_empty",
"html_cutoff",
"html_cutoff_text",
) )
@ -99,34 +113,34 @@ class RelatedField(Field):
html_cutoff_text = None html_cutoff_text = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', self.queryset) self.queryset = kwargs.pop("queryset", self.queryset)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None: if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings) cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings) self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings)
self.html_cutoff_text = kwargs.pop( self.html_cutoff_text = kwargs.pop(
'html_cutoff_text', "html_cutoff_text",
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT),
) )
if not method_overridden('get_queryset', RelatedField, self): if not method_overridden("get_queryset", RelatedField, self):
assert self.queryset is not None or kwargs.get('read_only', None), ( assert self.queryset is not None or kwargs.get("read_only", None), (
'Relational field must provide a `queryset` argument, ' "Relational field must provide a `queryset` argument, "
'override `get_queryset`, or set read_only=`True`.' "override `get_queryset`, or set read_only=`True`."
) )
assert not (self.queryset is not None and kwargs.get('read_only', None)), ( assert not (self.queryset is not None and kwargs.get("read_only", None)), (
'Relational fields should not provide a `queryset` argument, ' "Relational fields should not provide a `queryset` argument, "
'when setting read_only=`True`.' "when setting read_only=`True`."
) )
kwargs.pop('many', None) kwargs.pop("many", None)
kwargs.pop('allow_empty', None) kwargs.pop("allow_empty", None)
super(RelatedField, self).__init__(**kwargs) super(RelatedField, self).__init__(**kwargs)
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create # We override this method in order to automagically create
# `ManyRelatedField` classes instead when `many=True` is set. # `ManyRelatedField` classes instead when `many=True` is set.
if kwargs.pop('many', False): if kwargs.pop("many", False):
return cls.many_init(*args, **kwargs) return cls.many_init(*args, **kwargs)
return super(RelatedField, cls).__new__(cls, *args, **kwargs) return super(RelatedField, cls).__new__(cls, *args, **kwargs)
@ -147,7 +161,7 @@ class RelatedField(Field):
kwargs['child'] = cls() kwargs['child'] = cls()
return CustomManyRelatedField(*args, **kwargs) return CustomManyRelatedField(*args, **kwargs)
""" """
list_kwargs = {'child_relation': cls(*args, **kwargs)} list_kwargs = {"child_relation": cls(*args, **kwargs)}
for key in kwargs: for key in kwargs:
if key in MANY_RELATION_KWARGS: if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key] list_kwargs[key] = kwargs[key]
@ -155,7 +169,7 @@ class RelatedField(Field):
def run_validation(self, data=empty): def run_validation(self, data=empty):
# We force empty strings to None values for relational fields. # We force empty strings to None values for relational fields.
if data == '': if data == "":
data = None data = None
return super(RelatedField, self).run_validation(data) return super(RelatedField, self).run_validation(data)
@ -201,13 +215,12 @@ class RelatedField(Field):
if cutoff is not None: if cutoff is not None:
queryset = queryset[:cutoff] queryset = queryset[:cutoff]
return OrderedDict([ return OrderedDict(
( [
self.to_representation(item), (self.to_representation(item), self.display_value(item))
self.display_value(item) for item in queryset
) ]
for item in queryset )
])
@property @property
def choices(self): def choices(self):
@ -221,7 +234,7 @@ class RelatedField(Field):
return iter_options( return iter_options(
self.get_choices(cutoff=self.html_cutoff), self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff, cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text cutoff_text=self.html_cutoff_text,
) )
def display_value(self, instance): def display_value(self, instance):
@ -235,7 +248,7 @@ class StringRelatedField(RelatedField):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs['read_only'] = True kwargs["read_only"] = True
super(StringRelatedField, self).__init__(**kwargs) super(StringRelatedField, self).__init__(**kwargs)
def to_representation(self, value): def to_representation(self, value):
@ -244,13 +257,13 @@ class StringRelatedField(RelatedField):
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
default_error_messages = { default_error_messages = {
'required': _('This field is required.'), "required": _("This field is required."),
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'), "does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), "incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."),
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.pk_field = kwargs.pop('pk_field', None) self.pk_field = kwargs.pop("pk_field", None)
super(PrimaryKeyRelatedField, self).__init__(**kwargs) super(PrimaryKeyRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self): def use_pk_only_optimization(self):
@ -262,9 +275,9 @@ class PrimaryKeyRelatedField(RelatedField):
try: try:
return self.get_queryset().get(pk=data) return self.get_queryset().get(pk=data)
except ObjectDoesNotExist: except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=data) self.fail("does_not_exist", pk_value=data)
except (TypeError, ValueError): except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__) self.fail("incorrect_type", data_type=type(data).__name__)
def to_representation(self, value): def to_representation(self, value):
if self.pk_field is not None: if self.pk_field is not None:
@ -273,24 +286,26 @@ class PrimaryKeyRelatedField(RelatedField):
class HyperlinkedRelatedField(RelatedField): class HyperlinkedRelatedField(RelatedField):
lookup_field = 'pk' lookup_field = "pk"
view_name = None view_name = None
default_error_messages = { default_error_messages = {
'required': _('This field is required.'), "required": _("This field is required."),
'no_match': _('Invalid hyperlink - No URL match.'), "no_match": _("Invalid hyperlink - No URL match."),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'), "incorrect_match": _("Invalid hyperlink - Incorrect URL match."),
'does_not_exist': _('Invalid hyperlink - Object does not exist.'), "does_not_exist": _("Invalid hyperlink - Object does not exist."),
'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'), "incorrect_type": _(
"Incorrect type. Expected URL string, received {data_type}."
),
} }
def __init__(self, view_name=None, **kwargs): def __init__(self, view_name=None, **kwargs):
if view_name is not None: if view_name is not None:
self.view_name = view_name self.view_name = view_name
assert self.view_name is not None, 'The `view_name` argument is required.' assert self.view_name is not None, "The `view_name` argument is required."
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) self.lookup_field = kwargs.pop("lookup_field", self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) self.lookup_url_kwarg = kwargs.pop("lookup_url_kwarg", self.lookup_field)
self.format = kwargs.pop('format', None) self.format = kwargs.pop("format", None)
# We include this simply for dependency injection in tests. # We include this simply for dependency injection in tests.
# We can't add it as a class attributes or it would expect an # We can't add it as a class attributes or it would expect an
@ -300,7 +315,7 @@ class HyperlinkedRelatedField(RelatedField):
super(HyperlinkedRelatedField, self).__init__(**kwargs) super(HyperlinkedRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self): def use_pk_only_optimization(self):
return self.lookup_field == 'pk' return self.lookup_field == "pk"
def get_object(self, view_name, view_args, view_kwargs): def get_object(self, view_name, view_args, view_kwargs):
""" """
@ -330,7 +345,7 @@ class HyperlinkedRelatedField(RelatedField):
attributes are not configured to correctly match the URL conf. attributes are not configured to correctly match the URL conf.
""" """
# Unsaved objects will not yet have a valid URL. # Unsaved objects will not yet have a valid URL.
if hasattr(obj, 'pk') and obj.pk in (None, ''): if hasattr(obj, "pk") and obj.pk in (None, ""):
return None return None
lookup_value = getattr(obj, self.lookup_field) lookup_value = getattr(obj, self.lookup_field)
@ -338,25 +353,25 @@ class HyperlinkedRelatedField(RelatedField):
return self.reverse(view_name, kwargs=kwargs, request=request, format=format) return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
def to_internal_value(self, data): def to_internal_value(self, data):
request = self.context.get('request', None) request = self.context.get("request", None)
try: try:
http_prefix = data.startswith(('http:', 'https:')) http_prefix = data.startswith(("http:", "https:"))
except AttributeError: except AttributeError:
self.fail('incorrect_type', data_type=type(data).__name__) self.fail("incorrect_type", data_type=type(data).__name__)
if http_prefix: if http_prefix:
# If needed convert absolute URLs to relative path # If needed convert absolute URLs to relative path
data = urlparse.urlparse(data).path data = urlparse.urlparse(data).path
prefix = get_script_prefix() prefix = get_script_prefix()
if data.startswith(prefix): if data.startswith(prefix):
data = '/' + data[len(prefix):] data = "/" + data[len(prefix) :]
data = uri_to_iri(data) data = uri_to_iri(data)
try: try:
match = resolve(data) match = resolve(data)
except Resolver404: except Resolver404:
self.fail('no_match') self.fail("no_match")
try: try:
expected_viewname = request.versioning_scheme.get_versioned_viewname( expected_viewname = request.versioning_scheme.get_versioned_viewname(
@ -366,22 +381,22 @@ class HyperlinkedRelatedField(RelatedField):
expected_viewname = self.view_name expected_viewname = self.view_name
if match.view_name != expected_viewname: if match.view_name != expected_viewname:
self.fail('incorrect_match') self.fail("incorrect_match")
try: try:
return self.get_object(match.view_name, match.args, match.kwargs) return self.get_object(match.view_name, match.args, match.kwargs)
except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError): except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError):
self.fail('does_not_exist') self.fail("does_not_exist")
def to_representation(self, value): def to_representation(self, value):
assert 'request' in self.context, ( assert "request" in self.context, (
"`%s` requires the request in the serializer" "`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating " " context. Add `context={'request': request}` when instantiating "
"the serializer." % self.__class__.__name__ "the serializer." % self.__class__.__name__
) )
request = self.context['request'] request = self.context["request"]
format = self.context.get('format', None) format = self.context.get("format", None)
# By default use whatever format is given for the current context # By default use whatever format is given for the current context
# unless the target is a different type to the source. # unless the target is a different type to the source.
@ -400,13 +415,13 @@ class HyperlinkedRelatedField(RelatedField):
url = self.get_url(value, self.view_name, request, format) url = self.get_url(value, self.view_name, request, format)
except NoReverseMatch: except NoReverseMatch:
msg = ( msg = (
'Could not resolve URL for hyperlinked relationship using ' "Could not resolve URL for hyperlinked relationship using "
'view name "%s". You may have failed to include the related ' 'view name "%s". You may have failed to include the related '
'model in your API, or incorrectly configured the ' "model in your API, or incorrectly configured the "
'`lookup_field` attribute on this field.' "`lookup_field` attribute on this field."
) )
if value in ('', None): if value in ("", None):
value_string = {'': 'the empty string', None: 'None'}[value] value_string = {"": "the empty string", None: "None"}[value]
msg += ( msg += (
" WARNING: The value of the field on the model instance " " WARNING: The value of the field on the model instance "
"was %s, which may be why it didn't match any " "was %s, which may be why it didn't match any "
@ -429,9 +444,9 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
""" """
def __init__(self, view_name=None, **kwargs): def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.' assert view_name is not None, "The `view_name` argument is required."
kwargs['read_only'] = True kwargs["read_only"] = True
kwargs['source'] = '*' kwargs["source"] = "*"
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
def use_pk_only_optimization(self): def use_pk_only_optimization(self):
@ -445,13 +460,14 @@ class SlugRelatedField(RelatedField):
A read-write field that represents the target of the relationship A read-write field that represents the target of the relationship
by a unique 'slug' attribute. by a unique 'slug' attribute.
""" """
default_error_messages = { default_error_messages = {
'does_not_exist': _('Object with {slug_name}={value} does not exist.'), "does_not_exist": _("Object with {slug_name}={value} does not exist."),
'invalid': _('Invalid value.'), "invalid": _("Invalid value."),
} }
def __init__(self, slug_field=None, **kwargs): def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.' assert slug_field is not None, "The `slug_field` argument is required."
self.slug_field = slug_field self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs) super(SlugRelatedField, self).__init__(**kwargs)
@ -459,9 +475,11 @@ class SlugRelatedField(RelatedField):
try: try:
return self.get_queryset().get(**{self.slug_field: data}) return self.get_queryset().get(**{self.slug_field: data})
except ObjectDoesNotExist: except ObjectDoesNotExist:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data)) self.fail(
"does_not_exist", slug_name=self.slug_field, value=smart_text(data)
)
except (TypeError, ValueError): except (TypeError, ValueError):
self.fail('invalid') self.fail("invalid")
def to_representation(self, obj): def to_representation(self, obj):
return getattr(obj, self.slug_field) return getattr(obj, self.slug_field)
@ -479,31 +497,32 @@ class ManyRelatedField(Field):
You shouldn't generally need to be using this class directly yourself, You shouldn't generally need to be using this class directly yourself,
and should instead simply set 'many=True' on the relationship. and should instead simply set 'many=True' on the relationship.
""" """
initial = [] initial = []
default_empty_html = [] default_empty_html = []
default_error_messages = { default_error_messages = {
'not_a_list': _('Expected a list of items but got type "{input_type}".'), "not_a_list": _('Expected a list of items but got type "{input_type}".'),
'empty': _('This list may not be empty.') "empty": _("This list may not be empty."),
} }
html_cutoff = None html_cutoff = None
html_cutoff_text = None html_cutoff_text = None
def __init__(self, child_relation=None, *args, **kwargs): def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation self.child_relation = child_relation
self.allow_empty = kwargs.pop('allow_empty', True) self.allow_empty = kwargs.pop("allow_empty", True)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None: if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings) cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings) self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings)
self.html_cutoff_text = kwargs.pop( self.html_cutoff_text = kwargs.pop(
'html_cutoff_text', "html_cutoff_text",
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT),
) )
assert child_relation is not None, '`child_relation` is a required argument.' assert child_relation is not None, "`child_relation` is a required argument."
super(ManyRelatedField, self).__init__(*args, **kwargs) super(ManyRelatedField, self).__init__(*args, **kwargs)
self.child_relation.bind(field_name='', parent=self) self.child_relation.bind(field_name="", parent=self)
def get_value(self, dictionary): def get_value(self, dictionary):
# We override the default field access in order to support # We override the default field access in order to support
@ -511,36 +530,30 @@ class ManyRelatedField(Field):
if html.is_html_input(dictionary): if html.is_html_input(dictionary):
# Don't return [] if the update is partial # Don't return [] if the update is partial
if self.field_name not in dictionary: if self.field_name not in dictionary:
if getattr(self.root, 'partial', False): if getattr(self.root, "partial", False):
return empty return empty
return dictionary.getlist(self.field_name) return dictionary.getlist(self.field_name)
return dictionary.get(self.field_name, empty) return dictionary.get(self.field_name, empty)
def to_internal_value(self, data): def to_internal_value(self, data):
if isinstance(data, six.text_type) or not hasattr(data, '__iter__'): if isinstance(data, six.text_type) or not hasattr(data, "__iter__"):
self.fail('not_a_list', input_type=type(data).__name__) self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0: if not self.allow_empty and len(data) == 0:
self.fail('empty') self.fail("empty")
return [ return [self.child_relation.to_internal_value(item) for item in data]
self.child_relation.to_internal_value(item)
for item in data
]
def get_attribute(self, instance): def get_attribute(self, instance):
# Can't have any relationships if not created # Can't have any relationships if not created
if hasattr(instance, 'pk') and instance.pk is None: if hasattr(instance, "pk") and instance.pk is None:
return [] return []
relationship = get_attribute(instance, self.source_attrs) relationship = get_attribute(instance, self.source_attrs)
return relationship.all() if hasattr(relationship, 'all') else relationship return relationship.all() if hasattr(relationship, "all") else relationship
def to_representation(self, iterable): def to_representation(self, iterable):
return [ return [self.child_relation.to_representation(value) for value in iterable]
self.child_relation.to_representation(value)
for value in iterable
]
def get_choices(self, cutoff=None): def get_choices(self, cutoff=None):
return self.child_relation.get_choices(cutoff) return self.child_relation.get_choices(cutoff)
@ -557,5 +570,5 @@ class ManyRelatedField(Field):
return iter_options( return iter_options(
self.get_choices(cutoff=self.html_cutoff), self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff, cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text cutoff_text=self.html_cutoff_text,
) )

File diff suppressed because it is too large Load Diff

View File

@ -30,8 +30,10 @@ def is_form_media_type(media_type):
Return True if the media type is a valid form media type. Return True if the media type is a valid form media type.
""" """
base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING)) base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING))
return (base_media_type == 'application/x-www-form-urlencoded' or return (
base_media_type == 'multipart/form-data') base_media_type == "application/x-www-form-urlencoded"
or base_media_type == "multipart/form-data"
)
class override_method(object): class override_method(object):
@ -49,12 +51,12 @@ class override_method(object):
self.view = view self.view = view
self.request = request self.request = request
self.method = method self.method = method
self.action = getattr(view, 'action', None) self.action = getattr(view, "action", None)
def __enter__(self): def __enter__(self):
self.view.request = clone_request(self.request, self.method) self.view.request = clone_request(self.request, self.method)
# For viewsets we also set the `.action` attribute. # For viewsets we also set the `.action` attribute.
action_map = getattr(self.view, 'action_map', {}) action_map = getattr(self.view, "action_map", {})
self.view.action = action_map.get(self.method.lower()) self.view.action = action_map.get(self.method.lower())
return self.view.request return self.view.request
@ -86,6 +88,7 @@ class Empty(object):
Placeholder for unset attributes. Placeholder for unset attributes.
Cannot use `None`, as that may be a valid value. Cannot use `None`, as that may be a valid value.
""" """
pass pass
@ -98,30 +101,32 @@ def clone_request(request, method):
Internal helper method to clone a request, replacing with a different Internal helper method to clone a request, replacing with a different
HTTP method. Used for checking permissions against other methods. HTTP method. Used for checking permissions against other methods.
""" """
ret = Request(request=request._request, ret = Request(
parsers=request.parsers, request=request._request,
authenticators=request.authenticators, parsers=request.parsers,
negotiator=request.negotiator, authenticators=request.authenticators,
parser_context=request.parser_context) negotiator=request.negotiator,
parser_context=request.parser_context,
)
ret._data = request._data ret._data = request._data
ret._files = request._files ret._files = request._files
ret._full_data = request._full_data ret._full_data = request._full_data
ret._content_type = request._content_type ret._content_type = request._content_type
ret._stream = request._stream ret._stream = request._stream
ret.method = method ret.method = method
if hasattr(request, '_user'): if hasattr(request, "_user"):
ret._user = request._user ret._user = request._user
if hasattr(request, '_auth'): if hasattr(request, "_auth"):
ret._auth = request._auth ret._auth = request._auth
if hasattr(request, '_authenticator'): if hasattr(request, "_authenticator"):
ret._authenticator = request._authenticator ret._authenticator = request._authenticator
if hasattr(request, 'accepted_renderer'): if hasattr(request, "accepted_renderer"):
ret.accepted_renderer = request.accepted_renderer ret.accepted_renderer = request.accepted_renderer
if hasattr(request, 'accepted_media_type'): if hasattr(request, "accepted_media_type"):
ret.accepted_media_type = request.accepted_media_type ret.accepted_media_type = request.accepted_media_type
if hasattr(request, 'version'): if hasattr(request, "version"):
ret.version = request.version ret.version = request.version
if hasattr(request, 'versioning_scheme'): if hasattr(request, "versioning_scheme"):
ret.versioning_scheme = request.versioning_scheme ret.versioning_scheme = request.versioning_scheme
return ret return ret
@ -152,12 +157,19 @@ class Request(object):
authenticating the request's user. authenticating the request's user.
""" """
def __init__(self, request, parsers=None, authenticators=None, def __init__(
negotiator=None, parser_context=None): self,
request,
parsers=None,
authenticators=None,
negotiator=None,
parser_context=None,
):
assert isinstance(request, HttpRequest), ( assert isinstance(request, HttpRequest), (
'The `request` argument must be an instance of ' "The `request` argument must be an instance of "
'`django.http.HttpRequest`, not `{}.{}`.' "`django.http.HttpRequest`, not `{}.{}`.".format(
.format(request.__class__.__module__, request.__class__.__name__) request.__class__.__module__, request.__class__.__name__
)
) )
self._request = request self._request = request
@ -173,11 +185,11 @@ class Request(object):
if self.parser_context is None: if self.parser_context is None:
self.parser_context = {} self.parser_context = {}
self.parser_context['request'] = self self.parser_context["request"] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET self.parser_context["encoding"] = request.encoding or settings.DEFAULT_CHARSET
force_user = getattr(request, '_force_auth_user', None) force_user = getattr(request, "_force_auth_user", None)
force_token = getattr(request, '_force_auth_token', None) force_token = getattr(request, "_force_auth_token", None)
if force_user is not None or force_token is not None: if force_user is not None or force_token is not None:
forced_auth = ForcedAuthentication(force_user, force_token) forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,) self.authenticators = (forced_auth,)
@ -188,14 +200,14 @@ class Request(object):
@property @property
def content_type(self): def content_type(self):
meta = self._request.META meta = self._request.META
return meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', '')) return meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
@property @property
def stream(self): def stream(self):
""" """
Returns an object that may be used to stream the request content. Returns an object that may be used to stream the request content.
""" """
if not _hasattr(self, '_stream'): if not _hasattr(self, "_stream"):
self._load_stream() self._load_stream()
return self._stream return self._stream
@ -208,7 +220,7 @@ class Request(object):
@property @property
def data(self): def data(self):
if not _hasattr(self, '_full_data'): if not _hasattr(self, "_full_data"):
self._load_data_and_files() self._load_data_and_files()
return self._full_data return self._full_data
@ -218,7 +230,7 @@ class Request(object):
Returns the user associated with the current request, as authenticated Returns the user associated with the current request, as authenticated
by the authentication classes provided to the request. by the authentication classes provided to the request.
""" """
if not hasattr(self, '_user'): if not hasattr(self, "_user"):
with wrap_attributeerrors(): with wrap_attributeerrors():
self._authenticate() self._authenticate()
return self._user return self._user
@ -242,7 +254,7 @@ class Request(object):
Returns any non-user authentication information associated with the Returns any non-user authentication information associated with the
request, such as an authentication token. request, such as an authentication token.
""" """
if not hasattr(self, '_auth'): if not hasattr(self, "_auth"):
with wrap_attributeerrors(): with wrap_attributeerrors():
self._authenticate() self._authenticate()
return self._auth return self._auth
@ -262,7 +274,7 @@ class Request(object):
Return the instance of the authentication instance class that was used Return the instance of the authentication instance class that was used
to authenticate the request, or `None`. to authenticate the request, or `None`.
""" """
if not hasattr(self, '_authenticator'): if not hasattr(self, "_authenticator"):
with wrap_attributeerrors(): with wrap_attributeerrors():
self._authenticate() self._authenticate()
return self._authenticator return self._authenticator
@ -271,7 +283,7 @@ class Request(object):
""" """
Parses the request content into `self.data`. Parses the request content into `self.data`.
""" """
if not _hasattr(self, '_data'): if not _hasattr(self, "_data"):
self._data, self._files = self._parse() self._data, self._files = self._parse()
if self._files: if self._files:
self._full_data = self._data.copy() self._full_data = self._data.copy()
@ -292,7 +304,7 @@ class Request(object):
meta = self._request.META meta = self._request.META
try: try:
content_length = int( content_length = int(
meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)) meta.get("CONTENT_LENGTH", meta.get("HTTP_CONTENT_LENGTH", 0))
) )
except (ValueError, TypeError): except (ValueError, TypeError):
content_length = 0 content_length = 0
@ -308,10 +320,7 @@ class Request(object):
""" """
Return True if this requests supports parsing form data. Return True if this requests supports parsing form data.
""" """
form_media = ( form_media = ("application/x-www-form-urlencoded", "multipart/form-data")
'application/x-www-form-urlencoded',
'multipart/form-data'
)
return any([parser.media_type in form_media for parser in self.parsers]) return any([parser.media_type in form_media for parser in self.parsers])
def _parse(self): def _parse(self):
@ -324,7 +333,7 @@ class Request(object):
try: try:
stream = self.stream stream = self.stream
except RawPostDataException: except RawPostDataException:
if not hasattr(self._request, '_post'): if not hasattr(self._request, "_post"):
raise raise
# If request.POST has been accessed in middleware, and a method='POST' # If request.POST has been accessed in middleware, and a method='POST'
# request was made with 'multipart/form-data', then the request stream # request was made with 'multipart/form-data', then the request stream
@ -335,7 +344,7 @@ class Request(object):
if stream is None or media_type is None: if stream is None or media_type is None:
if media_type and is_form_media_type(media_type): if media_type and is_form_media_type(media_type):
empty_data = QueryDict('', encoding=self._request._encoding) empty_data = QueryDict("", encoding=self._request._encoding)
else: else:
empty_data = {} empty_data = {}
empty_files = MultiValueDict() empty_files = MultiValueDict()
@ -353,7 +362,7 @@ class Request(object):
# re-raise. Ensures we don't simply repeat the error when # re-raise. Ensures we don't simply repeat the error when
# attempting to render the browsable renderer response, or when # attempting to render the browsable renderer response, or when
# logging the request or similar. # logging the request or similar.
self._data = QueryDict('', encoding=self._request._encoding) self._data = QueryDict("", encoding=self._request._encoding)
self._files = MultiValueDict() self._files = MultiValueDict()
self._full_data = self._data self._full_data = self._data
raise raise
@ -416,33 +425,33 @@ class Request(object):
@property @property
def DATA(self): def DATA(self):
raise NotImplementedError( raise NotImplementedError(
'`request.DATA` has been deprecated in favor of `request.data` ' "`request.DATA` has been deprecated in favor of `request.data` "
'since version 3.0, and has been fully removed as of version 3.2.' "since version 3.0, and has been fully removed as of version 3.2."
) )
@property @property
def POST(self): def POST(self):
# Ensure that request.POST uses our request parsing. # Ensure that request.POST uses our request parsing.
if not _hasattr(self, '_data'): if not _hasattr(self, "_data"):
self._load_data_and_files() self._load_data_and_files()
if is_form_media_type(self.content_type): if is_form_media_type(self.content_type):
return self._data return self._data
return QueryDict('', encoding=self._request._encoding) return QueryDict("", encoding=self._request._encoding)
@property @property
def FILES(self): def FILES(self):
# Leave this one alone for backwards compat with Django's request.FILES # Leave this one alone for backwards compat with Django's request.FILES
# Different from the other two cases, which are not valid property # Different from the other two cases, which are not valid property
# names on the WSGIRequest class. # names on the WSGIRequest class.
if not _hasattr(self, '_files'): if not _hasattr(self, "_files"):
self._load_data_and_files() self._load_data_and_files()
return self._files return self._files
@property @property
def QUERY_PARAMS(self): def QUERY_PARAMS(self):
raise NotImplementedError( raise NotImplementedError(
'`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` ' "`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` "
'since version 3.0, and has been fully removed as of version 3.2.' "since version 3.0, and has been fully removed as of version 3.2."
) )
def force_plaintext_errors(self, value): def force_plaintext_errors(self, value):

View File

@ -19,9 +19,15 @@ class Response(SimpleTemplateResponse):
arbitrary media types. arbitrary media types.
""" """
def __init__(self, data=None, status=None, def __init__(
template_name=None, headers=None, self,
exception=False, content_type=None): data=None,
status=None,
template_name=None,
headers=None,
exception=False,
content_type=None,
):
""" """
Alters the init arguments slightly. Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'. For example, drop 'template_name', and instead use 'data'.
@ -33,9 +39,9 @@ class Response(SimpleTemplateResponse):
if isinstance(data, Serializer): if isinstance(data, Serializer):
msg = ( msg = (
'You passed a Serializer instance as data, but ' "You passed a Serializer instance as data, but "
'probably meant to pass serialized `.data` or ' "probably meant to pass serialized `.data` or "
'`.error`. representation.' "`.error`. representation."
) )
raise AssertionError(msg) raise AssertionError(msg)
@ -50,14 +56,14 @@ class Response(SimpleTemplateResponse):
@property @property
def rendered_content(self): def rendered_content(self):
renderer = getattr(self, 'accepted_renderer', None) renderer = getattr(self, "accepted_renderer", None)
accepted_media_type = getattr(self, 'accepted_media_type', None) accepted_media_type = getattr(self, "accepted_media_type", None)
context = getattr(self, 'renderer_context', None) context = getattr(self, "renderer_context", None)
assert renderer, ".accepted_renderer not set on Response" assert renderer, ".accepted_renderer not set on Response"
assert accepted_media_type, ".accepted_media_type not set on Response" assert accepted_media_type, ".accepted_media_type not set on Response"
assert context is not None, ".renderer_context not set on Response" assert context is not None, ".renderer_context not set on Response"
context['response'] = self context["response"] = self
media_type = renderer.media_type media_type = renderer.media_type
charset = renderer.charset charset = renderer.charset
@ -67,18 +73,17 @@ class Response(SimpleTemplateResponse):
content_type = "{0}; charset={1}".format(media_type, charset) content_type = "{0}; charset={1}".format(media_type, charset)
elif content_type is None: elif content_type is None:
content_type = media_type content_type = media_type
self['Content-Type'] = content_type self["Content-Type"] = content_type
ret = renderer.render(self.data, accepted_media_type, context) ret = renderer.render(self.data, accepted_media_type, context)
if isinstance(ret, six.text_type): if isinstance(ret, six.text_type):
assert charset, ( assert charset, (
'renderer returned unicode, and did not specify ' "renderer returned unicode, and did not specify " "a charset value."
'a charset value.'
) )
return bytes(ret.encode(charset)) return bytes(ret.encode(charset))
if not ret: if not ret:
del self['Content-Type'] del self["Content-Type"]
return ret return ret
@ -88,7 +93,7 @@ class Response(SimpleTemplateResponse):
Returns reason text corresponding to our HTTP response status code. Returns reason text corresponding to our HTTP response status code.
Provided for convenience. Provided for convenience.
""" """
return responses.get(self.status_code, '') return responses.get(self.status_code, "")
def __getstate__(self): def __getstate__(self):
""" """
@ -96,10 +101,15 @@ class Response(SimpleTemplateResponse):
""" """
state = super(Response, self).__getstate__() state = super(Response, self).__getstate__()
for key in ( for key in (
'accepted_renderer', 'renderer_context', 'resolver_match', "accepted_renderer",
'client', 'request', 'json', 'wsgi_request' "renderer_context",
"resolver_match",
"client",
"request",
"json",
"wsgi_request",
): ):
if key in state: if key in state:
del state[key] del state[key]
state['_closable_objects'] = [] state["_closable_objects"] = []
return state return state

View File

@ -3,8 +3,7 @@ Provide urlresolver functions that return fully qualified URLs or view names
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.urls import NoReverseMatch from django.urls import NoReverseMatch, reverse as django_reverse
from django.urls import reverse as django_reverse
from django.utils import six from django.utils import six
from django.utils.functional import lazy from django.utils.functional import lazy
@ -20,9 +19,7 @@ def preserve_builtin_query_params(url, request=None):
if request is None: if request is None:
return url return url
overrides = [ overrides = [api_settings.URL_FORMAT_OVERRIDE]
api_settings.URL_FORMAT_OVERRIDE,
]
for param in overrides: for param in overrides:
if param and (param in request.GET): if param and (param in request.GET):
@ -38,7 +35,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra
to the versioning scheme instance, so that the resulting URL to the versioning scheme instance, so that the resulting URL
can be modified if needed. can be modified if needed.
""" """
scheme = getattr(request, 'versioning_scheme', None) scheme = getattr(request, "versioning_scheme", None)
if scheme is not None: if scheme is not None:
try: try:
url = scheme.reverse(viewname, args, kwargs, request, format, **extra) url = scheme.reverse(viewname, args, kwargs, request, format, **extra)
@ -59,7 +56,7 @@ def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extr
""" """
if format is not None: if format is not None:
kwargs = kwargs or {} kwargs = kwargs or {}
kwargs['format'] = format kwargs["format"] = format
url = django_reverse(viewname, args=args, kwargs=kwargs, **extra) url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request: if request:
return request.build_absolute_uri(url) return request.build_absolute_uri(url)

View File

@ -25,9 +25,7 @@ from django.urls import NoReverseMatch
from django.utils import six from django.utils import six
from django.utils.deprecation import RenameMethodsBase from django.utils.deprecation import RenameMethodsBase
from rest_framework import ( from rest_framework import RemovedInDRF310Warning, RemovedInDRF311Warning, views
RemovedInDRF310Warning, RemovedInDRF311Warning, views
)
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator from rest_framework.schemas import SchemaGenerator
@ -35,8 +33,9 @@ from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs']) Route = namedtuple("Route", ["url", "mapping", "name", "detail", "initkwargs"])
DynamicRoute = namedtuple("DynamicRoute", ["url", "name", "detail", "initkwargs"])
class DynamicDetailRoute(object): class DynamicDetailRoute(object):
@ -45,7 +44,8 @@ class DynamicDetailRoute(object):
"`DynamicDetailRoute` is deprecated and will be removed in 3.10 " "`DynamicDetailRoute` is deprecated and will be removed in 3.10 "
"in favor of `DynamicRoute`, which accepts a `detail` boolean. Use " "in favor of `DynamicRoute`, which accepts a `detail` boolean. Use "
"`DynamicRoute(url, name, True, initkwargs)` instead.", "`DynamicRoute(url, name, True, initkwargs)` instead.",
RemovedInDRF310Warning, stacklevel=2 RemovedInDRF310Warning,
stacklevel=2,
) )
return DynamicRoute(url, name, True, initkwargs) return DynamicRoute(url, name, True, initkwargs)
@ -56,7 +56,8 @@ class DynamicListRoute(object):
"`DynamicListRoute` is deprecated and will be removed in 3.10 in " "`DynamicListRoute` is deprecated and will be removed in 3.10 in "
"favor of `DynamicRoute`, which accepts a `detail` boolean. Use " "favor of `DynamicRoute`, which accepts a `detail` boolean. Use "
"`DynamicRoute(url, name, False, initkwargs)` instead.", "`DynamicRoute(url, name, False, initkwargs)` instead.",
RemovedInDRF310Warning, stacklevel=2 RemovedInDRF310Warning,
stacklevel=2,
) )
return DynamicRoute(url, name, False, initkwargs) return DynamicRoute(url, name, False, initkwargs)
@ -65,8 +66,8 @@ def escape_curly_brackets(url_path):
""" """
Double brackets in regex of url_path for escape string formatting Double brackets in regex of url_path for escape string formatting
""" """
if ('{' and '}') in url_path: if ("{" and "}") in url_path:
url_path = url_path.replace('{', '{{').replace('}', '}}') url_path = url_path.replace("{", "{{").replace("}", "}}")
return url_path return url_path
@ -79,7 +80,7 @@ def flatten(list_of_lists):
class RenameRouterMethods(RenameMethodsBase): class RenameRouterMethods(RenameMethodsBase):
renamed_methods = ( renamed_methods = (
('get_default_base_name', 'get_default_basename', RemovedInDRF311Warning), ("get_default_base_name", "get_default_basename", RemovedInDRF311Warning),
) )
@ -92,8 +93,9 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
msg = "The `base_name` argument is pending deprecation in favor of `basename`." msg = "The `base_name` argument is pending deprecation in favor of `basename`."
warnings.warn(msg, RemovedInDRF311Warning, 2) warnings.warn(msg, RemovedInDRF311Warning, 2)
assert not (basename and base_name), ( assert not (
"Do not provide both the `basename` and `base_name` arguments.") basename and base_name
), "Do not provide both the `basename` and `base_name` arguments."
if basename is None: if basename is None:
basename = base_name basename = base_name
@ -103,7 +105,7 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
self.registry.append((prefix, viewset, basename)) self.registry.append((prefix, viewset, basename))
# invalidate the urls cache # invalidate the urls cache
if hasattr(self, '_urls'): if hasattr(self, "_urls"):
del self._urls del self._urls
def get_default_basename(self, viewset): def get_default_basename(self, viewset):
@ -111,17 +113,17 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
If `basename` is not specified, attempt to automatically determine If `basename` is not specified, attempt to automatically determine
it from the viewset. it from the viewset.
""" """
raise NotImplementedError('get_default_basename must be overridden') raise NotImplementedError("get_default_basename must be overridden")
def get_urls(self): def get_urls(self):
""" """
Return a list of URL patterns, given the registered viewsets. Return a list of URL patterns, given the registered viewsets.
""" """
raise NotImplementedError('get_urls must be overridden') raise NotImplementedError("get_urls must be overridden")
@property @property
def urls(self): def urls(self):
if not hasattr(self, '_urls'): if not hasattr(self, "_urls"):
self._urls = self.get_urls() self._urls = self.get_urls()
return self._urls return self._urls
@ -131,48 +133,45 @@ class SimpleRouter(BaseRouter):
routes = [ routes = [
# List route. # List route.
Route( Route(
url=r'^{prefix}{trailing_slash}$', url=r"^{prefix}{trailing_slash}$",
mapping={ mapping={"get": "list", "post": "create"},
'get': 'list', name="{basename}-list",
'post': 'create'
},
name='{basename}-list',
detail=False, detail=False,
initkwargs={'suffix': 'List'} initkwargs={"suffix": "List"},
), ),
# Dynamically generated list routes. Generated using # Dynamically generated list routes. Generated using
# @action(detail=False) decorator on methods of the viewset. # @action(detail=False) decorator on methods of the viewset.
DynamicRoute( DynamicRoute(
url=r'^{prefix}/{url_path}{trailing_slash}$', url=r"^{prefix}/{url_path}{trailing_slash}$",
name='{basename}-{url_name}', name="{basename}-{url_name}",
detail=False, detail=False,
initkwargs={} initkwargs={},
), ),
# Detail route. # Detail route.
Route( Route(
url=r'^{prefix}/{lookup}{trailing_slash}$', url=r"^{prefix}/{lookup}{trailing_slash}$",
mapping={ mapping={
'get': 'retrieve', "get": "retrieve",
'put': 'update', "put": "update",
'patch': 'partial_update', "patch": "partial_update",
'delete': 'destroy' "delete": "destroy",
}, },
name='{basename}-detail', name="{basename}-detail",
detail=True, detail=True,
initkwargs={'suffix': 'Instance'} initkwargs={"suffix": "Instance"},
), ),
# Dynamically generated detail routes. Generated using # Dynamically generated detail routes. Generated using
# @action(detail=True) decorator on methods of the viewset. # @action(detail=True) decorator on methods of the viewset.
DynamicRoute( DynamicRoute(
url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$', url=r"^{prefix}/{lookup}/{url_path}{trailing_slash}$",
name='{basename}-{url_name}', name="{basename}-{url_name}",
detail=True, detail=True,
initkwargs={} initkwargs={},
), ),
] ]
def __init__(self, trailing_slash=True): def __init__(self, trailing_slash=True):
self.trailing_slash = '/' if trailing_slash else '' self.trailing_slash = "/" if trailing_slash else ""
super(SimpleRouter, self).__init__() super(SimpleRouter, self).__init__()
def get_default_basename(self, viewset): def get_default_basename(self, viewset):
@ -180,11 +179,13 @@ class SimpleRouter(BaseRouter):
If `basename` is not specified, attempt to automatically determine If `basename` is not specified, attempt to automatically determine
it from the viewset. it from the viewset.
""" """
queryset = getattr(viewset, 'queryset', None) queryset = getattr(viewset, "queryset", None)
assert queryset is not None, '`basename` argument not specified, and could ' \ assert queryset is not None, (
'not automatically determine the name from the viewset, as ' \ "`basename` argument not specified, and could "
'it does not have a `.queryset` attribute.' "not automatically determine the name from the viewset, as "
"it does not have a `.queryset` attribute."
)
return queryset.model._meta.object_name.lower() return queryset.model._meta.object_name.lower()
@ -196,18 +197,29 @@ class SimpleRouter(BaseRouter):
""" """
# converting to list as iterables are good for one pass, known host needs to be checked again and again for # converting to list as iterables are good for one pass, known host needs to be checked again and again for
# different functions. # different functions.
known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)])) known_actions = list(
flatten(
[
route.mapping.values()
for route in self.routes
if isinstance(route, Route)
]
)
)
extra_actions = viewset.get_extra_actions() extra_actions = viewset.get_extra_actions()
# checking action names against the known actions list # checking action names against the known actions list
not_allowed = [ not_allowed = [
action.__name__ for action in extra_actions action.__name__
for action in extra_actions
if action.__name__ in known_actions if action.__name__ in known_actions
] ]
if not_allowed: if not_allowed:
msg = ('Cannot use the @action decorator on the following ' msg = (
'methods, as they are existing routes: %s') "Cannot use the @action decorator on the following "
raise ImproperlyConfigured(msg % ', '.join(not_allowed)) "methods, as they are existing routes: %s"
)
raise ImproperlyConfigured(msg % ", ".join(not_allowed))
# partition detail and list actions # partition detail and list actions
detail_actions = [action for action in extra_actions if action.detail] detail_actions = [action for action in extra_actions if action.detail]
@ -216,9 +228,13 @@ class SimpleRouter(BaseRouter):
routes = [] routes = []
for route in self.routes: for route in self.routes:
if isinstance(route, DynamicRoute) and route.detail: if isinstance(route, DynamicRoute) and route.detail:
routes += [self._get_dynamic_route(route, action) for action in detail_actions] routes += [
self._get_dynamic_route(route, action) for action in detail_actions
]
elif isinstance(route, DynamicRoute) and not route.detail: elif isinstance(route, DynamicRoute) and not route.detail:
routes += [self._get_dynamic_route(route, action) for action in list_actions] routes += [
self._get_dynamic_route(route, action) for action in list_actions
]
else: else:
routes.append(route) routes.append(route)
@ -231,9 +247,9 @@ class SimpleRouter(BaseRouter):
url_path = escape_curly_brackets(action.url_path) url_path = escape_curly_brackets(action.url_path)
return Route( return Route(
url=route.url.replace('{url_path}', url_path), url=route.url.replace("{url_path}", url_path),
mapping=action.mapping, mapping=action.mapping,
name=route.name.replace('{url_name}', action.url_name), name=route.name.replace("{url_name}", action.url_name),
detail=route.detail, detail=route.detail,
initkwargs=initkwargs, initkwargs=initkwargs,
) )
@ -250,7 +266,7 @@ class SimpleRouter(BaseRouter):
bound_methods[method] = action bound_methods[method] = action
return bound_methods return bound_methods
def get_lookup_regex(self, viewset, lookup_prefix=''): def get_lookup_regex(self, viewset, lookup_prefix=""):
""" """
Given a viewset, return the portion of URL regex that is used Given a viewset, return the portion of URL regex that is used
to match against a single instance. to match against a single instance.
@ -261,16 +277,16 @@ class SimpleRouter(BaseRouter):
https://github.com/alanjds/drf-nested-routers https://github.com/alanjds/drf-nested-routers
""" """
base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})' base_regex = "(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})"
# Use `pk` as default field, unset set. Default regex should not # Use `pk` as default field, unset set. Default regex should not
# consume `.json` style suffixes and should break at '/' boundaries. # consume `.json` style suffixes and should break at '/' boundaries.
lookup_field = getattr(viewset, 'lookup_field', 'pk') lookup_field = getattr(viewset, "lookup_field", "pk")
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field lookup_url_kwarg = getattr(viewset, "lookup_url_kwarg", None) or lookup_field
lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+') lookup_value = getattr(viewset, "lookup_value_regex", "[^/.]+")
return base_regex.format( return base_regex.format(
lookup_prefix=lookup_prefix, lookup_prefix=lookup_prefix,
lookup_url_kwarg=lookup_url_kwarg, lookup_url_kwarg=lookup_url_kwarg,
lookup_value=lookup_value lookup_value=lookup_value,
) )
def get_urls(self): def get_urls(self):
@ -292,23 +308,18 @@ class SimpleRouter(BaseRouter):
# Build the url pattern # Build the url pattern
regex = route.url.format( regex = route.url.format(
prefix=prefix, prefix=prefix, lookup=lookup, trailing_slash=self.trailing_slash
lookup=lookup,
trailing_slash=self.trailing_slash
) )
# If there is no prefix, the first part of the url is probably # If there is no prefix, the first part of the url is probably
# controlled by project's urls.py and the router is in an app, # controlled by project's urls.py and the router is in an app,
# so a slash in the beginning will (A) cause Django to give # so a slash in the beginning will (A) cause Django to give
# warnings and (B) generate URLS that will require using '//'. # warnings and (B) generate URLS that will require using '//'.
if not prefix and regex[:2] == '^/': if not prefix and regex[:2] == "^/":
regex = '^' + regex[2:] regex = "^" + regex[2:]
initkwargs = route.initkwargs.copy() initkwargs = route.initkwargs.copy()
initkwargs.update({ initkwargs.update({"basename": basename, "detail": route.detail})
'basename': basename,
'detail': route.detail,
})
view = viewset.as_view(mapping, **initkwargs) view = viewset.as_view(mapping, **initkwargs)
name = route.name.format(basename=basename) name = route.name.format(basename=basename)
@ -321,6 +332,7 @@ class APIRootView(views.APIView):
""" """
The default basic root view for DefaultRouter The default basic root view for DefaultRouter
""" """
_ignore_model_permissions = True _ignore_model_permissions = True
schema = None # exclude from schema schema = None # exclude from schema
api_root_dict = None api_root_dict = None
@ -331,14 +343,14 @@ class APIRootView(views.APIView):
namespace = request.resolver_match.namespace namespace = request.resolver_match.namespace
for key, url_name in self.api_root_dict.items(): for key, url_name in self.api_root_dict.items():
if namespace: if namespace:
url_name = namespace + ':' + url_name url_name = namespace + ":" + url_name
try: try:
ret[key] = reverse( ret[key] = reverse(
url_name, url_name,
args=args, args=args,
kwargs=kwargs, kwargs=kwargs,
request=request, request=request,
format=kwargs.get('format', None) format=kwargs.get("format", None),
) )
except NoReverseMatch: except NoReverseMatch:
# Don't bail out if eg. no list routes exist, only detail routes. # Don't bail out if eg. no list routes exist, only detail routes.
@ -352,17 +364,18 @@ class DefaultRouter(SimpleRouter):
The default router extends the SimpleRouter, but also adds in a default The default router extends the SimpleRouter, but also adds in a default
API root view, and adds format suffix patterns to the URLs. API root view, and adds format suffix patterns to the URLs.
""" """
include_root_view = True include_root_view = True
include_format_suffixes = True include_format_suffixes = True
root_view_name = 'api-root' root_view_name = "api-root"
default_schema_renderers = None default_schema_renderers = None
APIRootView = APIRootView APIRootView = APIRootView
APISchemaView = SchemaView APISchemaView = SchemaView
SchemaGenerator = SchemaGenerator SchemaGenerator = SchemaGenerator
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
if 'root_renderers' in kwargs: if "root_renderers" in kwargs:
self.root_renderers = kwargs.pop('root_renderers') self.root_renderers = kwargs.pop("root_renderers")
else: else:
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
super(DefaultRouter, self).__init__(*args, **kwargs) super(DefaultRouter, self).__init__(*args, **kwargs)
@ -387,7 +400,7 @@ class DefaultRouter(SimpleRouter):
if self.include_root_view: if self.include_root_view:
view = self.get_api_root_view(api_urls=urls) view = self.get_api_root_view(api_urls=urls)
root_url = url(r'^$', view, name=self.root_view_name) root_url = url(r"^$", view, name=self.root_view_name)
urls.append(root_url) urls.append(root_url)
if self.include_format_suffixes: if self.include_format_suffixes:

View File

@ -27,18 +27,29 @@ from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa
def get_schema_view( def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None, title=None,
public=False, patterns=None, generator_class=SchemaGenerator, url=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, description=None,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): urlconf=None,
renderer_classes=None,
public=False,
patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
):
""" """
Return a schema view. Return a schema view.
""" """
# Avoid import cycle on APIView # Avoid import cycle on APIView
from .views import SchemaView from .views import SchemaView
generator = generator_class( generator = generator_class(
title=title, url=url, description=description, title=title,
urlconf=urlconf, patterns=patterns, url=url,
description=description,
urlconf=urlconf,
patterns=patterns,
) )
return SchemaView.as_view( return SchemaView.as_view(
renderer_classes=renderer_classes, renderer_classes=renderer_classes,

View File

@ -15,7 +15,11 @@ from django.utils import six
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import ( from rest_framework.compat import (
URLPattern, URLResolver, coreapi, coreschema, get_original_route URLPattern,
URLResolver,
coreapi,
coreschema,
get_original_route,
) )
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -25,7 +29,7 @@ from .utils import is_list_view
def common_path(paths): def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths] split_paths = [path.strip("/").split("/") for path in paths]
s1 = min(split_paths) s1 = min(split_paths)
s2 = max(split_paths) s2 = max(split_paths)
common = s1 common = s1
@ -33,7 +37,7 @@ def common_path(paths):
if c != s2[i]: if c != s2[i]:
common = s1[:i] common = s1[:i]
break break
return '/' + '/'.join(common) return "/" + "/".join(common)
def get_pk_name(model): def get_pk_name(model):
@ -47,7 +51,8 @@ def is_api_view(callback):
""" """
# Avoid import cycle on APIView # Avoid import cycle on APIView
from rest_framework.views import APIView from rest_framework.views import APIView
cls = getattr(callback, 'cls', None)
cls = getattr(callback, "cls", None)
return (cls is not None) and issubclass(cls, APIView) return (cls is not None) and issubclass(cls, APIView)
@ -78,7 +83,7 @@ class LinkNode(OrderedDict):
current_val = self.methods_counter[preferred_key] current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1 self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val) key = "{}_{}".format(preferred_key, current_val)
if key not in self: if key not in self:
return key return key
@ -101,9 +106,7 @@ def insert_into(target, keys, value):
target.links.append((keys[-1], value)) target.links.append((keys[-1], value))
except TypeError: except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format( msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url, value_url=value.url, target_url=target.url, keys=keys
target_url=target.url,
keys=keys
) )
raise ValueError(msg) raise ValueError(msg)
@ -119,24 +122,25 @@ def distribute_links(obj):
def is_custom_action(action): def is_custom_action(action):
return action not in { return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' "retrieve",
"list",
"create",
"update",
"partial_update",
"destroy",
} }
def endpoint_ordering(endpoint): def endpoint_ordering(endpoint):
path, method, callback = endpoint path, method, callback = endpoint
method_priority = { method_priority = {"GET": 0, "POST": 1, "PUT": 2, "PATCH": 3, "DELETE": 4}.get(
'GET': 0, method, 5
'POST': 1, )
'PUT': 2,
'PATCH': 3,
'DELETE': 4
}.get(method, 5)
return (path, method_priority) return (path, method_priority)
_PATH_PARAMETER_COMPONENT_RE = re.compile( _PATH_PARAMETER_COMPONENT_RE = re.compile(
r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>' r"<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>"
) )
@ -144,6 +148,7 @@ class EndpointEnumerator(object):
""" """
A class to determine the available API endpoints that a project exposes. A class to determine the available API endpoints that a project exposes.
""" """
def __init__(self, patterns=None, urlconf=None): def __init__(self, patterns=None, urlconf=None):
if patterns is None: if patterns is None:
if urlconf is None: if urlconf is None:
@ -159,7 +164,7 @@ class EndpointEnumerator(object):
self.patterns = patterns self.patterns = patterns
def get_api_endpoints(self, patterns=None, prefix=''): def get_api_endpoints(self, patterns=None, prefix=""):
""" """
Return a list of all available API endpoints by inspecting the URL conf. Return a list of all available API endpoints by inspecting the URL conf.
""" """
@ -180,8 +185,7 @@ class EndpointEnumerator(object):
elif isinstance(pattern, URLResolver): elif isinstance(pattern, URLResolver):
nested_endpoints = self.get_api_endpoints( nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns, patterns=pattern.url_patterns, prefix=path_regex
prefix=path_regex
) )
api_endpoints.extend(nested_endpoints) api_endpoints.extend(nested_endpoints)
@ -196,7 +200,7 @@ class EndpointEnumerator(object):
path = simplify_regex(path_regex) path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format # Strip Django 2.0 convertors as they are incompatible with uritemplate format
path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g<parameter>}', path) path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r"{\g<parameter>}", path)
return path return path
def should_include_endpoint(self, path, callback): def should_include_endpoint(self, path, callback):
@ -209,11 +213,11 @@ class EndpointEnumerator(object):
if callback.cls.schema is None: if callback.cls.schema is None:
return False return False
if 'schema' in callback.initkwargs: if "schema" in callback.initkwargs:
if callback.initkwargs['schema'] is None: if callback.initkwargs["schema"] is None:
return False return False
if path.endswith('.{format}') or path.endswith('.{format}/'): if path.endswith(".{format}") or path.endswith(".{format}/"):
return False # Ignore .json style URLs. return False # Ignore .json style URLs.
return True return True
@ -222,24 +226,24 @@ class EndpointEnumerator(object):
""" """
Return a list of the valid HTTP methods for this endpoint. Return a list of the valid HTTP methods for this endpoint.
""" """
if hasattr(callback, 'actions'): if hasattr(callback, "actions"):
actions = set(callback.actions) actions = set(callback.actions)
http_method_names = set(callback.cls.http_method_names) http_method_names = set(callback.cls.http_method_names)
methods = [method.upper() for method in actions & http_method_names] methods = [method.upper() for method in actions & http_method_names]
else: else:
methods = callback.cls().allowed_methods methods = callback.cls().allowed_methods
return [method for method in methods if method not in ('OPTIONS', 'HEAD')] return [method for method in methods if method not in ("OPTIONS", "HEAD")]
class SchemaGenerator(object): class SchemaGenerator(object):
# Map HTTP methods onto actions. # Map HTTP methods onto actions.
default_mapping = { default_mapping = {
'get': 'retrieve', "get": "retrieve",
'post': 'create', "post": "create",
'put': 'update', "put": "update",
'patch': 'partial_update', "patch": "partial_update",
'delete': 'destroy', "delete": "destroy",
} }
endpoint_inspector_cls = EndpointEnumerator endpoint_inspector_cls = EndpointEnumerator
@ -253,12 +257,14 @@ class SchemaGenerator(object):
# Set by 'SCHEMA_COERCE_PATH_PK'. # Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): def __init__(
assert coreapi, '`coreapi` must be installed for schema support.' self, title=None, url=None, description=None, patterns=None, urlconf=None
assert coreschema, '`coreschema` must be installed for schema support.' ):
assert coreapi, "`coreapi` must be installed for schema support."
assert coreschema, "`coreschema` must be installed for schema support."
if url and not url.endswith('/'): if url and not url.endswith("/"):
url += '/' url += "/"
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
@ -288,8 +294,7 @@ class SchemaGenerator(object):
distribute_links(links) distribute_links(links)
return coreapi.Document( return coreapi.Document(
title=self.title, description=self.description, title=self.title, description=self.description, url=url, content=links
url=url, content=links
) )
def get_links(self, request=None): def get_links(self, request=None):
@ -317,7 +322,7 @@ class SchemaGenerator(object):
if not self.has_view_permissions(path, method, view): if not self.has_view_permissions(path, method, view):
continue continue
link = view.schema.get_link(path, method, base_url=self.url) link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):] subpath = path[len(prefix) :]
keys = self.get_keys(subpath, method, view) keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link) insert_into(links, keys, link)
@ -342,35 +347,35 @@ class SchemaGenerator(object):
""" """
prefixes = [] prefixes = []
for path in paths: for path in paths:
components = path.strip('/').split('/') components = path.strip("/").split("/")
initial_components = [] initial_components = []
for component in components: for component in components:
if '{' in component: if "{" in component:
break break
initial_components.append(component) initial_components.append(component)
prefix = '/'.join(initial_components[:-1]) prefix = "/".join(initial_components[:-1])
if not prefix: if not prefix:
# We can just break early in the case that there's at least # We can just break early in the case that there's at least
# one URL that doesn't have a path prefix. # one URL that doesn't have a path prefix.
return '/' return "/"
prefixes.append('/' + prefix + '/') prefixes.append("/" + prefix + "/")
return common_path(prefixes) return common_path(prefixes)
def create_view(self, callback, method, request=None): def create_view(self, callback, method, request=None):
""" """
Given a callback, return an actual view instance. Given a callback, return an actual view instance.
""" """
view = callback.cls(**getattr(callback, 'initkwargs', {})) view = callback.cls(**getattr(callback, "initkwargs", {}))
view.args = () view.args = ()
view.kwargs = {} view.kwargs = {}
view.format_kwarg = None view.format_kwarg = None
view.request = None view.request = None
view.action_map = getattr(callback, 'actions', None) view.action_map = getattr(callback, "actions", None)
actions = getattr(callback, 'actions', None) actions = getattr(callback, "actions", None)
if actions is not None: if actions is not None:
if method == 'OPTIONS': if method == "OPTIONS":
view.action = 'metadata' view.action = "metadata"
else: else:
view.action = actions.get(method.lower()) view.action = actions.get(method.lower())
@ -398,14 +403,14 @@ class SchemaGenerator(object):
where possible. This is cleaner for an external representation. where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key") (Ie. "this is an identifier", not "this is a database primary key")
""" """
if not self.coerce_path_pk or '{pk}' not in path: if not self.coerce_path_pk or "{pk}" not in path:
return path return path
model = getattr(getattr(view, 'queryset', None), 'model', None) model = getattr(getattr(view, "queryset", None), "model", None)
if model: if model:
field_name = get_pk_name(model) field_name = get_pk_name(model)
else: else:
field_name = 'id' field_name = "id"
return path.replace('{pk}', '{%s}' % field_name) return path.replace("{pk}", "{%s}" % field_name)
# Method for generating the link layout.... # Method for generating the link layout....
@ -421,20 +426,20 @@ class SchemaGenerator(object):
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
""" """
if hasattr(view, 'action'): if hasattr(view, "action"):
# Viewsets have explicitly named actions. # Viewsets have explicitly named actions.
action = view.action action = view.action
else: else:
# Views have no associated action, so we determine one from the method. # Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view): if is_list_view(subpath, method, view):
action = 'list' action = "list"
else: else:
action = self.default_mapping[method.lower()] action = self.default_mapping[method.lower()]
named_path_components = [ named_path_components = [
component for component component
in subpath.strip('/').split('/') for component in subpath.strip("/").split("/")
if '{' not in component if "{" not in component
] ]
if is_custom_action(action): if is_custom_action(action):

View File

@ -21,46 +21,38 @@ from rest_framework.utils import formatting
from .utils import is_list_view from .utils import is_list_view
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
header_regex = re.compile("^[a-zA-Z][0-9A-Za-z_]*:")
def field_to_schema(field): def field_to_schema(field):
title = force_text(field.label) if field.label else '' title = force_text(field.label) if field.label else ""
description = force_text(field.help_text) if field.help_text else '' description = force_text(field.help_text) if field.help_text else ""
if isinstance(field, (serializers.ListSerializer, serializers.ListField)): if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child) child_schema = field_to_schema(field.child)
return coreschema.Array( return coreschema.Array(
items=child_schema, items=child_schema, title=title, description=description
title=title,
description=description
) )
elif isinstance(field, serializers.DictField): elif isinstance(field, serializers.DictField):
return coreschema.Object( return coreschema.Object(title=title, description=description)
title=title,
description=description
)
elif isinstance(field, serializers.Serializer): elif isinstance(field, serializers.Serializer):
return coreschema.Object( return coreschema.Object(
properties=OrderedDict([ properties=OrderedDict(
(key, field_to_schema(value)) [(key, field_to_schema(value)) for key, value in field.fields.items()]
for key, value ),
in field.fields.items()
]),
title=title, title=title,
description=description description=description,
) )
elif isinstance(field, serializers.ManyRelatedField): elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation) related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array( return coreschema.Array(
items=related_field_schema, items=related_field_schema, title=title, description=description
title=title,
description=description
) )
elif isinstance(field, serializers.PrimaryKeyRelatedField): elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None) model = getattr(field.queryset, "model", None)
if model is not None: if model is not None:
model_field = model._meta.pk model_field = model._meta.pk
if isinstance(model_field, models.AutoField): if isinstance(model_field, models.AutoField):
@ -72,13 +64,11 @@ def field_to_schema(field):
return coreschema.Array( return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)), items=coreschema.Enum(enum=list(field.choices)),
title=title, title=title,
description=description description=description,
) )
elif isinstance(field, serializers.ChoiceField): elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum( return coreschema.Enum(
enum=list(field.choices), enum=list(field.choices), title=title, description=description
title=title,
description=description
) )
elif isinstance(field, serializers.BooleanField): elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description) return coreschema.Boolean(title=title, description=description)
@ -87,25 +77,17 @@ def field_to_schema(field):
elif isinstance(field, serializers.IntegerField): elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description) return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField): elif isinstance(field, serializers.DateField):
return coreschema.String( return coreschema.String(title=title, description=description, format="date")
title=title,
description=description,
format='date'
)
elif isinstance(field, serializers.DateTimeField): elif isinstance(field, serializers.DateTimeField):
return coreschema.String( return coreschema.String(
title=title, title=title, description=description, format="date-time"
description=description,
format='date-time'
) )
elif isinstance(field, serializers.JSONField): elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description) return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html': if field.style.get("base_template") == "textarea.html":
return coreschema.String( return coreschema.String(
title=title, title=title, description=description, format="textarea"
description=description,
format='textarea'
) )
return coreschema.String(title=title, description=description) return coreschema.String(title=title, description=description)
@ -113,15 +95,14 @@ def field_to_schema(field):
def get_pk_description(model, model_field): def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField): if isinstance(model_field, models.AutoField):
value_type = _('unique integer value') value_type = _("unique integer value")
elif isinstance(model_field, models.UUIDField): elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string') value_type = _("UUID string")
else: else:
value_type = _('unique value') value_type = _("unique value")
return _('A {value_type} identifying this {name}.').format( return _("A {value_type} identifying this {name}.").format(
value_type=value_type, value_type=value_type, name=model._meta.verbose_name
name=model._meta.verbose_name,
) )
@ -200,6 +181,7 @@ class AutoSchema(ViewInspector):
Responsible for per-view introspection and schema generation. Responsible for per-view introspection and schema generation.
""" """
def __init__(self, manual_fields=None): def __init__(self, manual_fields=None):
""" """
Parameters: Parameters:
@ -221,14 +203,14 @@ class AutoSchema(ViewInspector):
manual_fields = self.get_manual_fields(path, method) manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields) fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]): if fields and any([field.location in ("form", "body") for field in fields]):
encoding = self.get_encoding(path, method) encoding = self.get_encoding(path, method)
else: else:
encoding = None encoding = None
description = self.get_description(path, method) description = self.get_description(path, method)
if base_url and path.startswith('/'): if base_url and path.startswith("/"):
path = path[1:] path = path[1:]
return coreapi.Link( return coreapi.Link(
@ -236,7 +218,7 @@ class AutoSchema(ViewInspector):
action=method.lower(), action=method.lower(),
encoding=encoding, encoding=encoding,
fields=fields, fields=fields,
description=description description=description,
) )
def get_description(self, path, method): def get_description(self, path, method):
@ -248,25 +230,31 @@ class AutoSchema(ViewInspector):
""" """
view = self.view view = self.view
method_name = getattr(view, 'action', method.lower()) method_name = getattr(view, "action", method.lower())
method_docstring = getattr(view, method_name, None).__doc__ method_docstring = getattr(view, method_name, None).__doc__
if method_docstring: if method_docstring:
# An explicit docstring on the method or action. # An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring))) return self._get_description_section(
view, method.lower(), formatting.dedent(smart_text(method_docstring))
)
else: else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description()) return self._get_description_section(
view,
getattr(view, "action", method.lower()),
view.get_view_description(),
)
def _get_description_section(self, view, header, description): def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()] lines = [line for line in description.splitlines()]
current_section = '' current_section = ""
sections = {'': ''} sections = {"": ""}
for line in lines: for line in lines:
if header_regex.match(line): if header_regex.match(line):
current_section, seperator, lead = line.partition(':') current_section, seperator, lead = line.partition(":")
sections[current_section] = lead.strip() sections[current_section] = lead.strip()
else: else:
sections[current_section] += '\n' + line sections[current_section] += "\n" + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
@ -275,7 +263,7 @@ class AutoSchema(ViewInspector):
if header in coerce_method_names: if header in coerce_method_names:
if coerce_method_names[header] in sections: if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip() return sections[coerce_method_names[header]].strip()
return sections[''].strip() return sections[""].strip()
def get_path_fields(self, path, method): def get_path_fields(self, path, method):
""" """
@ -283,12 +271,12 @@ class AutoSchema(ViewInspector):
templated path variables. templated path variables.
""" """
view = self.view view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None) model = getattr(getattr(view, "queryset", None), "model", None)
fields = [] fields = []
for variable in uritemplate.variables(path): for variable in uritemplate.variables(path):
title = '' title = ""
description = '' description = ""
schema_cls = coreschema.String schema_cls = coreschema.String
kwargs = {} kwargs = {}
if model is not None: if model is not None:
@ -306,16 +294,19 @@ class AutoSchema(ViewInspector):
elif model_field is not None and model_field.primary_key: elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field) description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: if (
kwargs['pattern'] = view.lookup_value_regex hasattr(view, "lookup_value_regex")
and view.lookup_field == variable
):
kwargs["pattern"] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField): elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer schema_cls = coreschema.Integer
field = coreapi.Field( field = coreapi.Field(
name=variable, name=variable,
location='path', location="path",
required=True, required=True,
schema=schema_cls(title=title, description=description, **kwargs) schema=schema_cls(title=title, description=description, **kwargs),
) )
fields.append(field) fields.append(field)
@ -328,28 +319,29 @@ class AutoSchema(ViewInspector):
""" """
view = self.view view = self.view
if method not in ('PUT', 'PATCH', 'POST'): if method not in ("PUT", "PATCH", "POST"):
return [] return []
if not hasattr(view, 'get_serializer'): if not hasattr(view, "get_serializer"):
return [] return []
try: try:
serializer = view.get_serializer() serializer = view.get_serializer()
except exceptions.APIException: except exceptions.APIException:
serializer = None serializer = None
warnings.warn('{}.get_serializer() raised an exception during ' warnings.warn(
'schema generation. Serializer fields will not be ' "{}.get_serializer() raised an exception during "
'generated for {} {}.' "schema generation. Serializer fields will not be "
.format(view.__class__.__name__, method, path)) "generated for {} {}.".format(view.__class__.__name__, method, path)
)
if isinstance(serializer, serializers.ListSerializer): if isinstance(serializer, serializers.ListSerializer):
return [ return [
coreapi.Field( coreapi.Field(
name='data', name="data",
location='body', location="body",
required=True, required=True,
schema=coreschema.Array() schema=coreschema.Array(),
) )
] ]
@ -361,12 +353,12 @@ class AutoSchema(ViewInspector):
if field.read_only or isinstance(field, serializers.HiddenField): if field.read_only or isinstance(field, serializers.HiddenField):
continue continue
required = field.required and method != 'PATCH' required = field.required and method != "PATCH"
field = coreapi.Field( field = coreapi.Field(
name=field.field_name, name=field.field_name,
location='form', location="form",
required=required, required=required,
schema=field_to_schema(field) schema=field_to_schema(field),
) )
fields.append(field) fields.append(field)
@ -378,7 +370,7 @@ class AutoSchema(ViewInspector):
if not is_list_view(path, method, view): if not is_list_view(path, method, view):
return [] return []
pagination = getattr(view, 'pagination_class', None) pagination = getattr(view, "pagination_class", None)
if not pagination: if not pagination:
return [] return []
@ -397,11 +389,17 @@ class AutoSchema(ViewInspector):
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore) Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience. to allow changes based on user experience.
""" """
if getattr(self.view, 'filter_backends', None) is None: if getattr(self.view, "filter_backends", None) is None:
return False return False
if hasattr(self.view, 'action'): if hasattr(self.view, "action"):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] return self.view.action in [
"list",
"retrieve",
"update",
"partial_update",
"destroy",
]
return method.lower() in ["get", "put", "patch", "delete"] return method.lower() in ["get", "put", "patch", "delete"]
@ -447,18 +445,18 @@ class AutoSchema(ViewInspector):
# Core API supports the following request encodings over HTTP... # Core API supports the following request encodings over HTTP...
supported_media_types = { supported_media_types = {
'application/json', "application/json",
'application/x-www-form-urlencoded', "application/x-www-form-urlencoded",
'multipart/form-data', "multipart/form-data",
} }
parser_classes = getattr(view, 'parser_classes', []) parser_classes = getattr(view, "parser_classes", [])
for parser_class in parser_classes: for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None) media_type = getattr(parser_class, "media_type", None)
if media_type in supported_media_types: if media_type in supported_media_types:
return media_type return media_type
# Raw binary uploads are supported with "application/octet-stream" # Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*': if media_type == "*/*":
return 'application/octet-stream' return "application/octet-stream"
return None return None
@ -468,7 +466,8 @@ class ManualSchema(ViewInspector):
Allows providing a list of coreapi.Fields, Allows providing a list of coreapi.Fields,
plus an optional description. plus an optional description.
""" """
def __init__(self, fields, description='', encoding=None):
def __init__(self, fields, description="", encoding=None):
""" """
Parameters: Parameters:
@ -476,14 +475,16 @@ class ManualSchema(ViewInspector):
* `description`: String description for view. Optional. * `description`: String description for view. Optional.
""" """
super(ManualSchema, self).__init__() super(ManualSchema, self).__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" assert all(
isinstance(f, coreapi.Field) for f in fields
), "`fields` must be a list of coreapi.Field instances"
self._fields = fields self._fields = fields
self._description = description self._description = description
self._encoding = encoding self._encoding = encoding
def get_link(self, path, method, base_url): def get_link(self, path, method, base_url):
if base_url and path.startswith('/'): if base_url and path.startswith("/"):
path = path[1:] path = path[1:]
return coreapi.Link( return coreapi.Link(
@ -491,21 +492,22 @@ class ManualSchema(ViewInspector):
action=method.lower(), action=method.lower(),
encoding=self._encoding, encoding=self._encoding,
fields=self._fields, fields=self._fields,
description=self._description description=self._description,
) )
class DefaultSchema(ViewInspector): class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
def __get__(self, instance, owner): def __get__(self, instance, owner):
result = super(DefaultSchema, self).__get__(instance, owner) result = super(DefaultSchema, self).__get__(instance, owner)
if not isinstance(result, DefaultSchema): if not isinstance(result, DefaultSchema):
return result return result
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
assert issubclass(inspector_class, ViewInspector), ( assert issubclass(
"DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" inspector_class, ViewInspector
) ), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
inspector = inspector_class() inspector = inspector_class()
inspector.view = instance inspector.view = instance
return inspector return inspector

View File

@ -10,15 +10,15 @@ def is_list_view(path, method, view):
""" """
Return True if the given path/method appears to represent a list view. Return True if the given path/method appears to represent a list view.
""" """
if hasattr(view, 'action'): if hasattr(view, "action"):
# Viewsets have an explicitly defined action, which we can inspect. # Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list' return view.action == "list"
if method.lower() != 'get': if method.lower() != "get":
return False return False
if isinstance(view, RetrieveModelMixin): if isinstance(view, RetrieveModelMixin):
return False return False
path_components = path.strip('/').split('/') path_components = path.strip("/").split("/")
if path_components and '{' in path_components[-1]: if path_components and "{" in path_components[-1]:
return False return False
return True return True

View File

@ -21,7 +21,7 @@ class SchemaView(APIView):
if self.renderer_classes is None: if self.renderer_classes is None:
self.renderer_classes = [ self.renderer_classes = [
renderers.OpenAPIRenderer, renderers.OpenAPIRenderer,
renderers.CoreJSONRenderer renderers.CoreJSONRenderer,
] ]
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes += [renderers.BrowsableAPIRenderer] self.renderer_classes += [renderers.BrowsableAPIRenderer]

File diff suppressed because it is too large Load Diff

View File

@ -28,135 +28,109 @@ from django.utils import six
from rest_framework import ISO_8601 from rest_framework import ISO_8601
DEFAULTS = { DEFAULTS = {
# Base API policies # Base API policies
'DEFAULT_RENDERER_CLASSES': ( "DEFAULT_RENDERER_CLASSES": (
'rest_framework.renderers.JSONRenderer', "rest_framework.renderers.JSONRenderer",
'rest_framework.renderers.BrowsableAPIRenderer', "rest_framework.renderers.BrowsableAPIRenderer",
), ),
'DEFAULT_PARSER_CLASSES': ( "DEFAULT_PARSER_CLASSES": (
'rest_framework.parsers.JSONParser', "rest_framework.parsers.JSONParser",
'rest_framework.parsers.FormParser', "rest_framework.parsers.FormParser",
'rest_framework.parsers.MultiPartParser' "rest_framework.parsers.MultiPartParser",
), ),
'DEFAULT_AUTHENTICATION_CLASSES': ( "DEFAULT_AUTHENTICATION_CLASSES": (
'rest_framework.authentication.SessionAuthentication', "rest_framework.authentication.SessionAuthentication",
'rest_framework.authentication.BasicAuthentication' "rest_framework.authentication.BasicAuthentication",
), ),
'DEFAULT_PERMISSION_CLASSES': ( "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.AllowAny",),
'rest_framework.permissions.AllowAny', "DEFAULT_THROTTLE_CLASSES": (),
), "DEFAULT_CONTENT_NEGOTIATION_CLASS": "rest_framework.negotiation.DefaultContentNegotiation",
'DEFAULT_THROTTLE_CLASSES': (), "DEFAULT_METADATA_CLASS": "rest_framework.metadata.SimpleMetadata",
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', "DEFAULT_VERSIONING_CLASS": None,
'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
'DEFAULT_VERSIONING_CLASS': None,
# Generic view behavior # Generic view behavior
'DEFAULT_PAGINATION_CLASS': None, "DEFAULT_PAGINATION_CLASS": None,
'DEFAULT_FILTER_BACKENDS': (), "DEFAULT_FILTER_BACKENDS": (),
# Schema # Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', "DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.AutoSchema",
# Throttling # Throttling
'DEFAULT_THROTTLE_RATES': { "DEFAULT_THROTTLE_RATES": {"user": None, "anon": None},
'user': None, "NUM_PROXIES": None,
'anon': None,
},
'NUM_PROXIES': None,
# Pagination # Pagination
'PAGE_SIZE': None, "PAGE_SIZE": None,
# Filtering # Filtering
'SEARCH_PARAM': 'search', "SEARCH_PARAM": "search",
'ORDERING_PARAM': 'ordering', "ORDERING_PARAM": "ordering",
# Versioning # Versioning
'DEFAULT_VERSION': None, "DEFAULT_VERSION": None,
'ALLOWED_VERSIONS': None, "ALLOWED_VERSIONS": None,
'VERSION_PARAM': 'version', "VERSION_PARAM": "version",
# Authentication # Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', "UNAUTHENTICATED_USER": "django.contrib.auth.models.AnonymousUser",
'UNAUTHENTICATED_TOKEN': None, "UNAUTHENTICATED_TOKEN": None,
# View configuration # View configuration
'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', "VIEW_NAME_FUNCTION": "rest_framework.views.get_view_name",
'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', "VIEW_DESCRIPTION_FUNCTION": "rest_framework.views.get_view_description",
# Exception handling # Exception handling
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', "EXCEPTION_HANDLER": "rest_framework.views.exception_handler",
'NON_FIELD_ERRORS_KEY': 'non_field_errors', "NON_FIELD_ERRORS_KEY": "non_field_errors",
# Testing # Testing
'TEST_REQUEST_RENDERER_CLASSES': ( "TEST_REQUEST_RENDERER_CLASSES": (
'rest_framework.renderers.MultiPartRenderer', "rest_framework.renderers.MultiPartRenderer",
'rest_framework.renderers.JSONRenderer' "rest_framework.renderers.JSONRenderer",
), ),
'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', "TEST_REQUEST_DEFAULT_FORMAT": "multipart",
# Hyperlink settings # Hyperlink settings
'URL_FORMAT_OVERRIDE': 'format', "URL_FORMAT_OVERRIDE": "format",
'FORMAT_SUFFIX_KWARG': 'format', "FORMAT_SUFFIX_KWARG": "format",
'URL_FIELD_NAME': 'url', "URL_FIELD_NAME": "url",
# Input and output formats # Input and output formats
'DATE_FORMAT': ISO_8601, "DATE_FORMAT": ISO_8601,
'DATE_INPUT_FORMATS': (ISO_8601,), "DATE_INPUT_FORMATS": (ISO_8601,),
"DATETIME_FORMAT": ISO_8601,
'DATETIME_FORMAT': ISO_8601, "DATETIME_INPUT_FORMATS": (ISO_8601,),
'DATETIME_INPUT_FORMATS': (ISO_8601,), "TIME_FORMAT": ISO_8601,
"TIME_INPUT_FORMATS": (ISO_8601,),
'TIME_FORMAT': ISO_8601,
'TIME_INPUT_FORMATS': (ISO_8601,),
# Encoding # Encoding
'UNICODE_JSON': True, "UNICODE_JSON": True,
'COMPACT_JSON': True, "COMPACT_JSON": True,
'STRICT_JSON': True, "STRICT_JSON": True,
'COERCE_DECIMAL_TO_STRING': True, "COERCE_DECIMAL_TO_STRING": True,
'UPLOADED_FILES_USE_URL': True, "UPLOADED_FILES_USE_URL": True,
# Browseable API # Browseable API
'HTML_SELECT_CUTOFF': 1000, "HTML_SELECT_CUTOFF": 1000,
'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", "HTML_SELECT_CUTOFF_TEXT": "More than {count} items...",
# Schemas # Schemas
'SCHEMA_COERCE_PATH_PK': True, "SCHEMA_COERCE_PATH_PK": True,
'SCHEMA_COERCE_METHOD_NAMES': { "SCHEMA_COERCE_METHOD_NAMES": {"retrieve": "read", "destroy": "delete"},
'retrieve': 'read',
'destroy': 'delete'
},
} }
# List of settings that may be in string import notation. # List of settings that may be in string import notation.
IMPORT_STRINGS = ( IMPORT_STRINGS = (
'DEFAULT_RENDERER_CLASSES', "DEFAULT_RENDERER_CLASSES",
'DEFAULT_PARSER_CLASSES', "DEFAULT_PARSER_CLASSES",
'DEFAULT_AUTHENTICATION_CLASSES', "DEFAULT_AUTHENTICATION_CLASSES",
'DEFAULT_PERMISSION_CLASSES', "DEFAULT_PERMISSION_CLASSES",
'DEFAULT_THROTTLE_CLASSES', "DEFAULT_THROTTLE_CLASSES",
'DEFAULT_CONTENT_NEGOTIATION_CLASS', "DEFAULT_CONTENT_NEGOTIATION_CLASS",
'DEFAULT_METADATA_CLASS', "DEFAULT_METADATA_CLASS",
'DEFAULT_VERSIONING_CLASS', "DEFAULT_VERSIONING_CLASS",
'DEFAULT_PAGINATION_CLASS', "DEFAULT_PAGINATION_CLASS",
'DEFAULT_FILTER_BACKENDS', "DEFAULT_FILTER_BACKENDS",
'DEFAULT_SCHEMA_CLASS', "DEFAULT_SCHEMA_CLASS",
'EXCEPTION_HANDLER', "EXCEPTION_HANDLER",
'TEST_REQUEST_RENDERER_CLASSES', "TEST_REQUEST_RENDERER_CLASSES",
'UNAUTHENTICATED_USER', "UNAUTHENTICATED_USER",
'UNAUTHENTICATED_TOKEN', "UNAUTHENTICATED_TOKEN",
'VIEW_NAME_FUNCTION', "VIEW_NAME_FUNCTION",
'VIEW_DESCRIPTION_FUNCTION' "VIEW_DESCRIPTION_FUNCTION",
) )
# List of settings that have been removed # List of settings that have been removed
REMOVED_SETTINGS = ( REMOVED_SETTINGS = ("PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY")
"PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY",
)
def perform_import(val, setting_name): def perform_import(val, setting_name):
@ -179,11 +153,16 @@ def import_from_string(val, setting_name):
""" """
try: try:
# Nod to tastypie's use of importlib. # Nod to tastypie's use of importlib.
module_path, class_name = val.rsplit('.', 1) module_path, class_name = val.rsplit(".", 1)
module = import_module(module_path) module = import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) msg = "Could not import '%s' for API setting '%s'. %s: %s." % (
val,
setting_name,
e.__class__.__name__,
e,
)
raise ImportError(msg) raise ImportError(msg)
@ -198,6 +177,7 @@ class APISettings(object):
Any setting with string import paths will be automatically resolved Any setting with string import paths will be automatically resolved
and return the class, rather than the string literal. and return the class, rather than the string literal.
""" """
def __init__(self, user_settings=None, defaults=None, import_strings=None): def __init__(self, user_settings=None, defaults=None, import_strings=None):
if user_settings: if user_settings:
self._user_settings = self.__check_user_settings(user_settings) self._user_settings = self.__check_user_settings(user_settings)
@ -207,8 +187,8 @@ class APISettings(object):
@property @property
def user_settings(self): def user_settings(self):
if not hasattr(self, '_user_settings'): if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, 'REST_FRAMEWORK', {}) self._user_settings = getattr(settings, "REST_FRAMEWORK", {})
return self._user_settings return self._user_settings
def __getattr__(self, attr): def __getattr__(self, attr):
@ -235,23 +215,26 @@ class APISettings(object):
SETTINGS_DOC = "https://www.django-rest-framework.org/api-guide/settings/" SETTINGS_DOC = "https://www.django-rest-framework.org/api-guide/settings/"
for setting in REMOVED_SETTINGS: for setting in REMOVED_SETTINGS:
if setting in user_settings: if setting in user_settings:
raise RuntimeError("The '%s' setting has been removed. Please refer to '%s' for available settings." % (setting, SETTINGS_DOC)) raise RuntimeError(
"The '%s' setting has been removed. Please refer to '%s' for available settings."
% (setting, SETTINGS_DOC)
)
return user_settings return user_settings
def reload(self): def reload(self):
for attr in self._cached_attrs: for attr in self._cached_attrs:
delattr(self, attr) delattr(self, attr)
self._cached_attrs.clear() self._cached_attrs.clear()
if hasattr(self, '_user_settings'): if hasattr(self, "_user_settings"):
delattr(self, '_user_settings') delattr(self, "_user_settings")
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS) api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_api_settings(*args, **kwargs): def reload_api_settings(*args, **kwargs):
setting = kwargs['setting'] setting = kwargs["setting"]
if setting == 'REST_FRAMEWORK': if setting == "REST_FRAMEWORK":
api_settings.reload() api_settings.reload()

View File

@ -15,22 +15,23 @@ from rest_framework.compat import apply_markdown, pygments_highlight
from rest_framework.renderers import HTMLFormRenderer from rest_framework.renderers import HTMLFormRenderer
from rest_framework.utils.urls import replace_query_param from rest_framework.utils.urls import replace_query_param
register = template.Library() register = template.Library()
# Regex for adding classes to html snippets # Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
@register.tag(name='code') @register.tag(name="code")
def highlight_code(parser, token): def highlight_code(parser, token):
code = token.split_contents()[-1] code = token.split_contents()[-1]
nodelist = parser.parse(('endcode',)) nodelist = parser.parse(("endcode",))
parser.delete_first_token() parser.delete_first_token()
return CodeNode(code, nodelist) return CodeNode(code, nodelist)
class CodeNode(template.Node): class CodeNode(template.Node):
style = 'emacs' style = "emacs"
def __init__(self, lang, code): def __init__(self, lang, code):
self.lang = lang self.lang = lang
@ -43,24 +44,17 @@ class CodeNode(template.Node):
@register.filter() @register.filter()
def with_location(fields, location): def with_location(fields, location):
return [ return [field for field in fields if field.location == location]
field for field in fields
if field.location == location
]
@register.simple_tag @register.simple_tag
def form_for_link(link): def form_for_link(link):
import coreschema import coreschema
properties = OrderedDict([
(field.name, field.schema or coreschema.String()) properties = OrderedDict(
for field in link.fields [(field.name, field.schema or coreschema.String()) for field in link.fields]
]) )
required = [ required = [field.name for field in link.fields if field.required]
field.name
for field in link.fields
if field.required
]
schema = coreschema.Object(properties=properties, required=required) schema = coreschema.Object(properties=properties, required=required)
return mark_safe(coreschema.render_to_form(schema)) return mark_safe(coreschema.render_to_form(schema))
@ -79,14 +73,14 @@ def get_pagination_html(pager):
@register.simple_tag @register.simple_tag
def render_form(serializer, template_pack=None): def render_form(serializer, template_pack=None):
style = {'template_pack': template_pack} if template_pack else {} style = {"template_pack": template_pack} if template_pack else {}
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
return renderer.render(serializer.data, None, {'style': style}) return renderer.render(serializer.data, None, {"style": style})
@register.simple_tag @register.simple_tag
def render_field(field, style): def render_field(field, style):
renderer = style.get('renderer', HTMLFormRenderer()) renderer = style.get("renderer", HTMLFormRenderer())
return renderer.render_field(field, style) return renderer.render_field(field, style)
@ -96,9 +90,9 @@ def optional_login(request):
Include a login snippet if REST framework's login view is in the URLconf. Include a login snippet if REST framework's login view is in the URLconf.
""" """
try: try:
login_url = reverse('rest_framework:login') login_url = reverse("rest_framework:login")
except NoReverseMatch: except NoReverseMatch:
return '' return ""
snippet = "<li><a href='{href}?next={next}'>Log in</a></li>" snippet = "<li><a href='{href}?next={next}'>Log in</a></li>"
snippet = format_html(snippet, href=login_url, next=escape(request.path)) snippet = format_html(snippet, href=login_url, next=escape(request.path))
@ -112,9 +106,9 @@ def optional_docs_login(request):
Include a login snippet if REST framework's login view is in the URLconf. Include a login snippet if REST framework's login view is in the URLconf.
""" """
try: try:
login_url = reverse('rest_framework:login') login_url = reverse("rest_framework:login")
except NoReverseMatch: except NoReverseMatch:
return 'log in' return "log in"
snippet = "<a href='{href}?next={next}'>log in</a>" snippet = "<a href='{href}?next={next}'>log in</a>"
snippet = format_html(snippet, href=login_url, next=escape(request.path)) snippet = format_html(snippet, href=login_url, next=escape(request.path))
@ -128,7 +122,7 @@ def optional_logout(request, user):
Include a logout snippet if REST framework's logout view is in the URLconf. Include a logout snippet if REST framework's logout view is in the URLconf.
""" """
try: try:
logout_url = reverse('rest_framework:logout') logout_url = reverse("rest_framework:logout")
except NoReverseMatch: except NoReverseMatch:
snippet = format_html('<li class="navbar-text">{user}</li>', user=escape(user)) snippet = format_html('<li class="navbar-text">{user}</li>', user=escape(user))
return mark_safe(snippet) return mark_safe(snippet)
@ -142,7 +136,9 @@ def optional_logout(request, user):
<li><a href='{href}?next={next}'>Log out</a></li> <li><a href='{href}?next={next}'>Log out</a></li>
</ul> </ul>
</li>""" </li>"""
snippet = format_html(snippet, user=escape(user), href=logout_url, next=escape(request.path)) snippet = format_html(
snippet, user=escape(user), href=logout_url, next=escape(request.path)
)
return mark_safe(snippet) return mark_safe(snippet)
@ -160,16 +156,13 @@ def add_query_param(request, key, val):
@register.filter @register.filter
def as_string(value): def as_string(value):
if value is None: if value is None:
return '' return ""
return '%s' % value return "%s" % value
@register.filter @register.filter
def as_list_of_strings(value): def as_list_of_strings(value):
return [ return ["" if (item is None) else ("%s" % item) for item in value]
'' if (item is None) else ('%s' % item)
for item in value
]
@register.filter @register.filter
@ -190,45 +183,52 @@ def add_class(value, css_class):
html = six.text_type(value) html = six.text_type(value)
match = class_re.search(html) match = class_re.search(html)
if match: if match:
m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class, m = re.search(
css_class, css_class), r"^%s$|^%s\s|\s%s\s|\s%s$" % (css_class, css_class, css_class, css_class),
match.group(1)) match.group(1),
)
if not m: if not m:
return mark_safe(class_re.sub(match.group(1) + " " + css_class, return mark_safe(class_re.sub(match.group(1) + " " + css_class, html))
html))
else: else:
return mark_safe(html.replace('>', ' class="%s">' % css_class, 1)) return mark_safe(html.replace(">", ' class="%s">' % css_class, 1))
return value return value
@register.filter @register.filter
def format_value(value): def format_value(value):
if getattr(value, 'is_hyperlink', False): if getattr(value, "is_hyperlink", False):
name = six.text_type(value.obj) name = six.text_type(value.obj)
return mark_safe('<a href=%s>%s</a>' % (value, escape(name))) return mark_safe("<a href=%s>%s</a>" % (value, escape(name)))
if value is None or isinstance(value, bool): if value is None or isinstance(value, bool):
return mark_safe('<code>%s</code>' % {True: 'true', False: 'false', None: 'null'}[value]) return mark_safe(
"<code>%s</code>" % {True: "true", False: "false", None: "null"}[value]
)
elif isinstance(value, list): elif isinstance(value, list):
if any([isinstance(item, (list, dict)) for item in value]): if any([isinstance(item, (list, dict)) for item in value]):
template = loader.get_template('rest_framework/admin/list_value.html') template = loader.get_template("rest_framework/admin/list_value.html")
else: else:
template = loader.get_template('rest_framework/admin/simple_list_value.html') template = loader.get_template(
context = {'value': value} "rest_framework/admin/simple_list_value.html"
)
context = {"value": value}
return template.render(context) return template.render(context)
elif isinstance(value, dict): elif isinstance(value, dict):
template = loader.get_template('rest_framework/admin/dict_value.html') template = loader.get_template("rest_framework/admin/dict_value.html")
context = {'value': value} context = {"value": value}
return template.render(context) return template.render(context)
elif isinstance(value, six.string_types): elif isinstance(value, six.string_types):
if ( if (value.startswith("http:") or value.startswith("https:")) and not re.search(
(value.startswith('http:') or value.startswith('https:')) and not r"\s", value
re.search(r'\s', value)
): ):
return mark_safe('<a href="{value}">{value}</a>'.format(value=escape(value))) return mark_safe(
elif '@' in value and not re.search(r'\s', value): '<a href="{value}">{value}</a>'.format(value=escape(value))
return mark_safe('<a href="mailto:{value}">{value}</a>'.format(value=escape(value))) )
elif '\n' in value: elif "@" in value and not re.search(r"\s", value):
return mark_safe('<pre>%s</pre>' % escape(value)) return mark_safe(
'<a href="mailto:{value}">{value}</a>'.format(value=escape(value))
)
elif "\n" in value:
return mark_safe("<pre>%s</pre>" % escape(value))
return six.text_type(value) return six.text_type(value)
@ -266,7 +266,7 @@ def schema_links(section, sec_key=None):
""" """
Recursively find every link in a schema, even nested. Recursively find every link in a schema, even nested.
""" """
NESTED_FORMAT = '%s > %s' # this format is used in docs/js/api.js:normalizeKeys NESTED_FORMAT = "%s > %s" # this format is used in docs/js/api.js:normalizeKeys
links = section.links links = section.links
if section.data: if section.data:
data = section.data.items() data = section.data.items()
@ -287,20 +287,30 @@ def schema_links(section, sec_key=None):
@register.filter @register.filter
def add_nested_class(value): def add_nested_class(value):
if isinstance(value, dict): if isinstance(value, dict):
return 'class=nested' return "class=nested"
if isinstance(value, list) and any([isinstance(item, (list, dict)) for item in value]): if isinstance(value, list) and any(
return 'class=nested' [isinstance(item, (list, dict)) for item in value]
return '' ):
return "class=nested"
return ""
# Bunch of stuff cloned from urlize # Bunch of stuff cloned from urlize
TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"] TRAILING_PUNCTUATION = [".", ",", ":", ";", ".)", '"', "']", "'}", "'"]
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'), WRAPPING_PUNCTUATION = [
('"', '"'), ("'", "'")] ("(", ")"),
word_split_re = re.compile(r'(\s+)') ("<", ">"),
simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE) ("[", "]"),
simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) ("&lt;", "&gt;"),
simple_email_re = re.compile(r'^\S+@\S+\.\S+$') ('"', '"'),
("'", "'"),
]
word_split_re = re.compile(r"(\s+)")
simple_url_re = re.compile(r"^https?://\[?\w", re.IGNORECASE)
simple_url_2_re = re.compile(
r"^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$", re.IGNORECASE
)
simple_email_re = re.compile(r"^\S+@\S+\.\S+$")
def smart_urlquote_wrapper(matched_url): def smart_urlquote_wrapper(matched_url):
@ -332,8 +342,13 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
If autoescape is True, the link text and URLs will get autoescaped. If autoescape is True, the link text and URLs will get autoescaped.
""" """
def trim_url(x, limit=trim_url_limit): def trim_url(x, limit=trim_url_limit):
return limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x return (
limit is not None
and (len(x) > limit and ("%s..." % x[: max(0, limit - 3)]))
or x
)
safe_input = isinstance(text, SafeData) safe_input = isinstance(text, SafeData)
@ -344,40 +359,40 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words = word_split_re.split(force_text(text)) words = word_split_re.split(force_text(text))
for i, word in enumerate(words): for i, word in enumerate(words):
if '.' in word or '@' in word or ':' in word: if "." in word or "@" in word or ":" in word:
# Deal with punctuation. # Deal with punctuation.
lead, middle, trail = '', word, '' lead, middle, trail = "", word, ""
for punctuation in TRAILING_PUNCTUATION: for punctuation in TRAILING_PUNCTUATION:
if middle.endswith(punctuation): if middle.endswith(punctuation):
middle = middle[:-len(punctuation)] middle = middle[: -len(punctuation)]
trail = punctuation + trail trail = punctuation + trail
for opening, closing in WRAPPING_PUNCTUATION: for opening, closing in WRAPPING_PUNCTUATION:
if middle.startswith(opening): if middle.startswith(opening):
middle = middle[len(opening):] middle = middle[len(opening) :]
lead = lead + opening lead = lead + opening
# Keep parentheses at the end only if they're balanced. # Keep parentheses at the end only if they're balanced.
if ( if (
middle.endswith(closing) and middle.endswith(closing)
middle.count(closing) == middle.count(opening) + 1 and middle.count(closing) == middle.count(opening) + 1
): ):
middle = middle[:-len(closing)] middle = middle[: -len(closing)]
trail = closing + trail trail = closing + trail
# Make URL we want to point to. # Make URL we want to point to.
url = None url = None
nofollow_attr = ' rel="nofollow"' if nofollow else '' nofollow_attr = ' rel="nofollow"' if nofollow else ""
if simple_url_re.match(middle): if simple_url_re.match(middle):
url = smart_urlquote_wrapper(middle) url = smart_urlquote_wrapper(middle)
elif simple_url_2_re.match(middle): elif simple_url_2_re.match(middle):
url = smart_urlquote_wrapper('http://%s' % middle) url = smart_urlquote_wrapper("http://%s" % middle)
elif ':' not in middle and simple_email_re.match(middle): elif ":" not in middle and simple_email_re.match(middle):
local, domain = middle.rsplit('@', 1) local, domain = middle.rsplit("@", 1)
try: try:
domain = domain.encode('idna').decode('ascii') domain = domain.encode("idna").decode("ascii")
except UnicodeError: except UnicodeError:
continue continue
url = 'mailto:%s@%s' % (local, domain) url = "mailto:%s@%s" % (local, domain)
nofollow_attr = '' nofollow_attr = ""
# Make link. # Make link.
if url: if url:
@ -385,12 +400,12 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
lead, trail = conditional_escape(lead), conditional_escape(trail) lead, trail = conditional_escape(lead), conditional_escape(trail)
url, trimmed = conditional_escape(url), conditional_escape(trimmed) url, trimmed = conditional_escape(url), conditional_escape(trimmed)
middle = '<a href="%s"%s>%s</a>' % (url, nofollow_attr, trimmed) middle = '<a href="%s"%s>%s</a>' % (url, nofollow_attr, trimmed)
words[i] = '%s%s%s' % (lead, middle, trail) words[i] = "%s%s%s" % (lead, middle, trail)
else: else:
words[i] = conditional_escape(word) words[i] = conditional_escape(word)
else: else:
words[i] = conditional_escape(word) words[i] = conditional_escape(word)
return mark_safe(''.join(words)) return mark_safe("".join(words))
@register.filter @register.filter
@ -399,6 +414,6 @@ def break_long_headers(header):
Breaks headers longer than 160 characters (~page length) Breaks headers longer than 160 characters (~page length)
when possible (are comma separated) when possible (are comma separated)
""" """
if len(header) > 160 and ',' in header: if len(header) > 160 and "," in header:
header = mark_safe('<br> ' + ', <br>'.join(header.split(','))) header = mark_safe("<br> " + ", <br>".join(header.split(",")))
return header return header

View File

@ -11,9 +11,11 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler from django.core.handlers.wsgi import WSGIHandler
from django.test import override_settings, testcases from django.test import override_settings, testcases
from django.test.client import Client as DjangoClient from django.test.client import (
from django.test.client import ClientHandler Client as DjangoClient,
from django.test.client import RequestFactory as DjangoRequestFactory ClientHandler,
RequestFactory as DjangoRequestFactory,
)
from django.utils import six from django.utils import six
from django.utils.encoding import force_bytes from django.utils.encoding import force_bytes
from django.utils.http import urlencode from django.utils.http import urlencode
@ -28,6 +30,7 @@ def force_authenticate(request, user=None, token=None):
if requests is not None: if requests is not None:
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default): def get_all(self, key, default):
return self.getheaders(key) return self.getheaders(key)
@ -48,6 +51,7 @@ if requests is not None:
A transport adapter for `requests`, that makes requests via the A transport adapter for `requests`, that makes requests via the
Django WSGI app, rather than making actual HTTP requests over the network. Django WSGI app, rather than making actual HTTP requests over the network.
""" """
def __init__(self): def __init__(self):
self.app = WSGIHandler() self.app = WSGIHandler()
self.factory = DjangoRequestFactory() self.factory = DjangoRequestFactory()
@ -62,19 +66,19 @@ if requests is not None:
# Set request content, if any exists. # Set request content, if any exists.
if request.body is not None: if request.body is not None:
if hasattr(request.body, 'read'): if hasattr(request.body, "read"):
kwargs['data'] = request.body.read() kwargs["data"] = request.body.read()
else: else:
kwargs['data'] = request.body kwargs["data"] = request.body
if 'content-type' in request.headers: if "content-type" in request.headers:
kwargs['content_type'] = request.headers['content-type'] kwargs["content_type"] = request.headers["content-type"]
# Set request headers. # Set request headers.
for key, value in request.headers.items(): for key, value in request.headers.items():
key = key.upper() key = key.upper()
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): if key in ("CONNECTION", "CONTENT-LENGTH", "CONTENT-TYPE"):
continue continue
kwargs['HTTP_%s' % key.replace('-', '_')] = value kwargs["HTTP_%s" % key.replace("-", "_")] = value
return self.factory.generic(method, url, **kwargs).environ return self.factory.generic(method, url, **kwargs).environ
@ -85,20 +89,20 @@ if requests is not None:
raw_kwargs = {} raw_kwargs = {}
def start_response(wsgi_status, wsgi_headers): def start_response(wsgi_status, wsgi_headers):
status, _, reason = wsgi_status.partition(' ') status, _, reason = wsgi_status.partition(" ")
raw_kwargs['status'] = int(status) raw_kwargs["status"] = int(status)
raw_kwargs['reason'] = reason raw_kwargs["reason"] = reason
raw_kwargs['headers'] = wsgi_headers raw_kwargs["headers"] = wsgi_headers
raw_kwargs['version'] = 11 raw_kwargs["version"] = 11
raw_kwargs['preload_content'] = False raw_kwargs["preload_content"] = False
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) raw_kwargs["original_response"] = MockOriginalResponse(wsgi_headers)
# Make the outgoing request via WSGI. # Make the outgoing request via WSGI.
environ = self.get_environ(request) environ = self.get_environ(request)
wsgi_response = self.app(environ, start_response) wsgi_response = self.app(environ, start_response)
# Build the underlying urllib3.HTTPResponse # Build the underlying urllib3.HTTPResponse
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) raw_kwargs["body"] = io.BytesIO(b"".join(wsgi_response))
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
# Build the requests.Response # Build the requests.Response
@ -111,33 +115,47 @@ if requests is not None:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(RequestsClient, self).__init__(*args, **kwargs) super(RequestsClient, self).__init__(*args, **kwargs)
adapter = DjangoTestAdapter() adapter = DjangoTestAdapter()
self.mount('http://', adapter) self.mount("http://", adapter)
self.mount('https://', adapter) self.mount("https://", adapter)
def request(self, method, url, *args, **kwargs): def request(self, method, url, *args, **kwargs):
if not url.startswith('http'): if not url.startswith("http"):
raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url) raise ValueError(
'Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"'
% url
)
return super(RequestsClient, self).request(method, url, *args, **kwargs) return super(RequestsClient, self).request(method, url, *args, **kwargs)
else: else:
def RequestsClient(*args, **kwargs): def RequestsClient(*args, **kwargs):
raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.') raise ImproperlyConfigured(
"requests must be installed in order to use RequestsClient."
)
if coreapi is not None: if coreapi is not None:
class CoreAPIClient(coreapi.Client): class CoreAPIClient(coreapi.Client):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self._session = RequestsClient() self._session = RequestsClient()
kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)] kwargs["transports"] = [
coreapi.transports.HTTPTransport(session=self.session)
]
return super(CoreAPIClient, self).__init__(*args, **kwargs) return super(CoreAPIClient, self).__init__(*args, **kwargs)
@property @property
def session(self): def session(self):
return self._session return self._session
else: else:
def CoreAPIClient(*args, **kwargs): def CoreAPIClient(*args, **kwargs):
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.') raise ImproperlyConfigured(
"coreapi must be installed in order to use CoreAPIClient."
)
class APIRequestFactory(DjangoRequestFactory): class APIRequestFactory(DjangoRequestFactory):
@ -157,11 +175,11 @@ class APIRequestFactory(DjangoRequestFactory):
""" """
if data is None: if data is None:
return ('', content_type) return ("", content_type)
assert format is None or content_type is None, ( assert (
'You may not set both `format` and `content_type`.' format is None or content_type is None
) ), "You may not set both `format` and `content_type`."
if content_type: if content_type:
# Content type specified explicitly, treat data as a raw bytestring # Content type specified explicitly, treat data as a raw bytestring
@ -175,7 +193,7 @@ class APIRequestFactory(DjangoRequestFactory):
"Set TEST_REQUEST_RENDERER_CLASSES to enable " "Set TEST_REQUEST_RENDERER_CLASSES to enable "
"extra request formats.".format( "extra request formats.".format(
format, format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes]) ", ".join(["'" + fmt + "'" for fmt in self.renderer_classes]),
) )
) )
@ -195,47 +213,53 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type return ret, content_type
def get(self, path, data=None, **extra): def get(self, path, data=None, **extra):
r = { r = {"QUERY_STRING": urlencode(data or {}, doseq=True)}
'QUERY_STRING': urlencode(data or {}, doseq=True), if not data and "?" in path:
}
if not data and '?' in path:
# Fix to support old behavior where you have the arguments in the # Fix to support old behavior where you have the arguments in the
# url. See #1461. # url. See #1461.
query_string = force_bytes(path.split('?')[1]) query_string = force_bytes(path.split("?")[1])
if six.PY3: if six.PY3:
query_string = query_string.decode('iso-8859-1') query_string = query_string.decode("iso-8859-1")
r['QUERY_STRING'] = query_string r["QUERY_STRING"] = query_string
r.update(extra) r.update(extra)
return self.generic('GET', path, **r) return self.generic("GET", path, **r)
def post(self, path, data=None, format=None, content_type=None, **extra): def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra) return self.generic("POST", path, data, content_type, **extra)
def put(self, path, data=None, format=None, content_type=None, **extra): def put(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('PUT', path, data, content_type, **extra) return self.generic("PUT", path, data, content_type, **extra)
def patch(self, path, data=None, format=None, content_type=None, **extra): def patch(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('PATCH', path, data, content_type, **extra) return self.generic("PATCH", path, data, content_type, **extra)
def delete(self, path, data=None, format=None, content_type=None, **extra): def delete(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('DELETE', path, data, content_type, **extra) return self.generic("DELETE", path, data, content_type, **extra)
def options(self, path, data=None, format=None, content_type=None, **extra): def options(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('OPTIONS', path, data, content_type, **extra) return self.generic("OPTIONS", path, data, content_type, **extra)
def generic(self, method, path, data='', def generic(
content_type='application/octet-stream', secure=False, **extra): self,
method,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra
):
# Include the CONTENT_TYPE, regardless of whether or not data is empty. # Include the CONTENT_TYPE, regardless of whether or not data is empty.
if content_type is not None: if content_type is not None:
extra['CONTENT_TYPE'] = str(content_type) extra["CONTENT_TYPE"] = str(content_type)
return super(APIRequestFactory, self).generic( return super(APIRequestFactory, self).generic(
method, path, data, content_type, secure, **extra) method, path, data, content_type, secure, **extra
)
def request(self, **kwargs): def request(self, **kwargs):
request = super(APIRequestFactory, self).request(**kwargs) request = super(APIRequestFactory, self).request(**kwargs)
@ -294,42 +318,52 @@ class APIClient(APIRequestFactory, DjangoClient):
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
def post(self, path, data=None, format=None, content_type=None, def post(
follow=False, **extra): self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).post( response = super(APIClient, self).post(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra
)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
def put(self, path, data=None, format=None, content_type=None, def put(
follow=False, **extra): self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).put( response = super(APIClient, self).put(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra
)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
def patch(self, path, data=None, format=None, content_type=None, def patch(
follow=False, **extra): self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).patch( response = super(APIClient, self).patch(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra
)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
def delete(self, path, data=None, format=None, content_type=None, def delete(
follow=False, **extra): self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).delete( response = super(APIClient, self).delete(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra
)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
def options(self, path, data=None, format=None, content_type=None, def options(
follow=False, **extra): self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).options( response = super(APIClient, self).options(
path, data=data, format=format, content_type=content_type, **extra) path, data=data, format=format, content_type=content_type, **extra
)
if follow: if follow:
response = self._handle_redirects(response, **extra) response = self._handle_redirects(response, **extra)
return response return response
@ -377,13 +411,14 @@ class URLPatternsTestCase(testcases.SimpleTestCase):
def test_something_else(self): def test_something_else(self):
... ...
""" """
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# Get the module of the TestCase subclass # Get the module of the TestCase subclass
cls._module = import_module(cls.__module__) cls._module = import_module(cls.__module__)
cls._override = override_settings(ROOT_URLCONF=cls.__module__) cls._override = override_settings(ROOT_URLCONF=cls.__module__)
if hasattr(cls._module, 'urlpatterns'): if hasattr(cls._module, "urlpatterns"):
cls._module_urlpatterns = cls._module.urlpatterns cls._module_urlpatterns = cls._module.urlpatterns
cls._module.urlpatterns = cls.urlpatterns cls._module.urlpatterns = cls.urlpatterns
@ -396,7 +431,7 @@ class URLPatternsTestCase(testcases.SimpleTestCase):
super(URLPatternsTestCase, cls).tearDownClass() super(URLPatternsTestCase, cls).tearDownClass()
cls._override.disable() cls._override.disable()
if hasattr(cls, '_module_urlpatterns'): if hasattr(cls, "_module_urlpatterns"):
cls._module.urlpatterns = cls._module_urlpatterns cls._module.urlpatterns = cls._module_urlpatterns
else: else:
del cls._module.urlpatterns del cls._module.urlpatterns

View File

@ -20,7 +20,7 @@ class BaseThrottle(object):
""" """
Return `True` if the request should be allowed, `False` otherwise. Return `True` if the request should be allowed, `False` otherwise.
""" """
raise NotImplementedError('.allow_request() must be overridden') raise NotImplementedError(".allow_request() must be overridden")
def get_ident(self, request): def get_ident(self, request):
""" """
@ -28,18 +28,18 @@ class BaseThrottle(object):
if present and number of proxies is > 0. If not use all of if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
""" """
xff = request.META.get('HTTP_X_FORWARDED_FOR') xff = request.META.get("HTTP_X_FORWARDED_FOR")
remote_addr = request.META.get('REMOTE_ADDR') remote_addr = request.META.get("REMOTE_ADDR")
num_proxies = api_settings.NUM_PROXIES num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None: if num_proxies is not None:
if num_proxies == 0 or xff is None: if num_proxies == 0 or xff is None:
return remote_addr return remote_addr
addrs = xff.split(',') addrs = xff.split(",")
client_addr = addrs[-min(num_proxies, len(addrs))] client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip() return client_addr.strip()
return ''.join(xff.split()) if xff else remote_addr return "".join(xff.split()) if xff else remote_addr
def wait(self): def wait(self):
""" """
@ -61,14 +61,15 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache. Previous request information used for throttling is stored in the cache.
""" """
cache = default_cache cache = default_cache
timer = time.time timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s' cache_format = "throttle_%(scope)s_%(ident)s"
scope = None scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self): def __init__(self):
if not getattr(self, 'rate', None): if not getattr(self, "rate", None):
self.rate = self.get_rate() self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate) self.num_requests, self.duration = self.parse_rate(self.rate)
@ -79,15 +80,17 @@ class SimpleRateThrottle(BaseThrottle):
May return `None` if the request should not be throttled. May return `None` if the request should not be throttled.
""" """
raise NotImplementedError('.get_cache_key() must be overridden') raise NotImplementedError(".get_cache_key() must be overridden")
def get_rate(self): def get_rate(self):
""" """
Determine the string representation of the allowed request rate. Determine the string representation of the allowed request rate.
""" """
if not getattr(self, 'scope', None): if not getattr(self, "scope", None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % msg = (
self.__class__.__name__) "You must set either `.scope` or `.rate` for '%s' throttle"
% self.__class__.__name__
)
raise ImproperlyConfigured(msg) raise ImproperlyConfigured(msg)
try: try:
@ -103,9 +106,9 @@ class SimpleRateThrottle(BaseThrottle):
""" """
if rate is None: if rate is None:
return (None, None) return (None, None)
num, period = rate.split('/') num, period = rate.split("/")
num_requests = int(num) num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]]
return (num_requests, duration) return (num_requests, duration)
def allow_request(self, request, view): def allow_request(self, request, view):
@ -170,15 +173,16 @@ class AnonRateThrottle(SimpleRateThrottle):
The IP address of the request will be used as the unique cache key. The IP address of the request will be used as the unique cache key.
""" """
scope = 'anon'
scope = "anon"
def get_cache_key(self, request, view): def get_cache_key(self, request, view):
if request.user.is_authenticated: if request.user.is_authenticated:
return None # Only throttle unauthenticated requests. return None # Only throttle unauthenticated requests.
return self.cache_format % { return self.cache_format % {
'scope': self.scope, "scope": self.scope,
'ident': self.get_ident(request) "ident": self.get_ident(request),
} }
@ -190,7 +194,8 @@ class UserRateThrottle(SimpleRateThrottle):
authenticated. For anonymous requests, the IP address of the request will authenticated. For anonymous requests, the IP address of the request will
be used. be used.
""" """
scope = 'user'
scope = "user"
def get_cache_key(self, request, view): def get_cache_key(self, request, view):
if request.user.is_authenticated: if request.user.is_authenticated:
@ -198,10 +203,7 @@ class UserRateThrottle(SimpleRateThrottle):
else: else:
ident = self.get_ident(request) ident = self.get_ident(request)
return self.cache_format % { return self.cache_format % {"scope": self.scope, "ident": ident}
'scope': self.scope,
'ident': ident
}
class ScopedRateThrottle(SimpleRateThrottle): class ScopedRateThrottle(SimpleRateThrottle):
@ -211,7 +213,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
throttled. The unique cache key will be generated by concatenating the throttled. The unique cache key will be generated by concatenating the
user id of the request, and the scope of the view being accessed. user id of the request, and the scope of the view being accessed.
""" """
scope_attr = 'throttle_scope'
scope_attr = "throttle_scope"
def __init__(self): def __init__(self):
# Override the usual SimpleRateThrottle, because we can't determine # Override the usual SimpleRateThrottle, because we can't determine
@ -246,7 +249,4 @@ class ScopedRateThrottle(SimpleRateThrottle):
else: else:
ident = self.get_ident(request) ident = self.get_ident(request)
return self.cache_format % { return self.cache_format % {"scope": self.scope, "ident": ident}
'scope': self.scope,
'ident': ident
}

View File

@ -3,7 +3,11 @@ from __future__ import unicode_literals
from django.conf.urls import include, url from django.conf.urls import include, url
from rest_framework.compat import ( from rest_framework.compat import (
URLResolver, get_regex_pattern, is_route_pattern, path, register_converter URLResolver,
get_regex_pattern,
is_route_pattern,
path,
register_converter,
) )
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -13,7 +17,7 @@ def _get_format_path_converter(suffix_kwarg, allowed):
if len(allowed) == 1: if len(allowed) == 1:
allowed_pattern = allowed[0] allowed_pattern = allowed[0]
else: else:
allowed_pattern = '(?:%s)' % '|'.join(allowed) allowed_pattern = "(?:%s)" % "|".join(allowed)
suffix_pattern = r"\.%s/?" % allowed_pattern suffix_pattern = r"\.%s/?" % allowed_pattern
else: else:
suffix_pattern = r"\.[a-z0-9]+/?" suffix_pattern = r"\.[a-z0-9]+/?"
@ -22,19 +26,21 @@ def _get_format_path_converter(suffix_kwarg, allowed):
regex = suffix_pattern regex = suffix_pattern
def to_python(self, value): def to_python(self, value):
return value.strip('./') return value.strip("./")
def to_url(self, value): def to_url(self, value):
return '.' + value + '/' return "." + value + "/"
converter_name = 'drf_format_suffix' converter_name = "drf_format_suffix"
if allowed: if allowed:
converter_name += '_' + '_'.join(allowed) converter_name += "_" + "_".join(allowed)
return converter_name, FormatSuffixConverter return converter_name, FormatSuffixConverter
def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route=None): def apply_suffix_patterns(
urlpatterns, suffix_pattern, suffix_required, suffix_route=None
):
ret = [] ret = []
for urlpattern in urlpatterns: for urlpattern in urlpatterns:
if isinstance(urlpattern, URLResolver): if isinstance(urlpattern, URLResolver):
@ -44,23 +50,28 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
app_name = urlpattern.app_name app_name = urlpattern.app_name
kwargs = urlpattern.default_kwargs kwargs = urlpattern.default_kwargs
# Add in the included patterns, after applying the suffixes # Add in the included patterns, after applying the suffixes
patterns = apply_suffix_patterns(urlpattern.url_patterns, patterns = apply_suffix_patterns(
suffix_pattern, urlpattern.url_patterns, suffix_pattern, suffix_required, suffix_route
suffix_required, )
suffix_route)
# if the original pattern was a RoutePattern we need to preserve it # if the original pattern was a RoutePattern we need to preserve it
if is_route_pattern(urlpattern): if is_route_pattern(urlpattern):
assert path is not None assert path is not None
route = str(urlpattern.pattern) route = str(urlpattern.pattern)
new_pattern = path(route, include((patterns, app_name), namespace), kwargs) new_pattern = path(
route, include((patterns, app_name), namespace), kwargs
)
else: else:
new_pattern = url(regex, include((patterns, app_name), namespace), kwargs) new_pattern = url(
regex, include((patterns, app_name), namespace), kwargs
)
ret.append(new_pattern) ret.append(new_pattern)
else: else:
# Regular URL pattern # Regular URL pattern
regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern regex = (
get_regex_pattern(urlpattern).rstrip("$").rstrip("/") + suffix_pattern
)
view = urlpattern.callback view = urlpattern.callback
kwargs = urlpattern.default_args kwargs = urlpattern.default_args
name = urlpattern.name name = urlpattern.name
@ -72,7 +83,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
if is_route_pattern(urlpattern): if is_route_pattern(urlpattern):
assert path is not None assert path is not None
assert suffix_route is not None assert suffix_route is not None
route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route route = str(urlpattern.pattern).rstrip("$").rstrip("/") + suffix_route
new_pattern = path(route, view, kwargs, name) new_pattern = path(route, view, kwargs, name)
else: else:
new_pattern = url(regex, view, kwargs, name) new_pattern = url(regex, view, kwargs, name)
@ -103,17 +114,21 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
if len(allowed) == 1: if len(allowed) == 1:
allowed_pattern = allowed[0] allowed_pattern = allowed[0]
else: else:
allowed_pattern = '(%s)' % '|'.join(allowed) allowed_pattern = "(%s)" % "|".join(allowed)
suffix_pattern = r'\.(?P<%s>%s)/?$' % (suffix_kwarg, allowed_pattern) suffix_pattern = r"\.(?P<%s>%s)/?$" % (suffix_kwarg, allowed_pattern)
else: else:
suffix_pattern = r'\.(?P<%s>[a-z0-9]+)/?$' % suffix_kwarg suffix_pattern = r"\.(?P<%s>[a-z0-9]+)/?$" % suffix_kwarg
if path and register_converter: if path and register_converter:
converter_name, suffix_converter = _get_format_path_converter(suffix_kwarg, allowed) converter_name, suffix_converter = _get_format_path_converter(
suffix_kwarg, allowed
)
register_converter(suffix_converter, converter_name) register_converter(suffix_converter, converter_name)
suffix_route = '<%s:%s>' % (converter_name, suffix_kwarg) suffix_route = "<%s:%s>" % (converter_name, suffix_kwarg)
else: else:
suffix_route = None suffix_route = None
return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route) return apply_suffix_patterns(
urlpatterns, suffix_pattern, suffix_required, suffix_route
)

View File

@ -16,8 +16,13 @@ from __future__ import unicode_literals
from django.conf.urls import url from django.conf.urls import url
from django.contrib.auth import views from django.contrib.auth import views
app_name = 'rest_framework'
app_name = "rest_framework"
urlpatterns = [ urlpatterns = [
url(r'^login/$', views.LoginView.as_view(template_name='rest_framework/login.html'), name='login'), url(
url(r'^logout/$', views.LogoutView.as_view(), name='logout'), r"^login/$",
views.LoginView.as_view(template_name="rest_framework/login.html"),
name="login",
),
url(r"^logout/$", views.LogoutView.as_view(), name="logout"),
] ]

View File

@ -23,8 +23,8 @@ def get_breadcrumbs(url, request=None):
else: else:
# Check if this is a REST framework view, # Check if this is a REST framework view,
# and if so add it to the breadcrumbs # and if so add it to the breadcrumbs
cls = getattr(view, 'cls', None) cls = getattr(view, "cls", None)
initkwargs = getattr(view, 'initkwargs', {}) initkwargs = getattr(view, "initkwargs", {})
if cls is not None and issubclass(cls, APIView): if cls is not None and issubclass(cls, APIView):
# Don't list the same view twice in a row. # Don't list the same view twice in a row.
# Probably an optional trailing slash. # Probably an optional trailing slash.
@ -35,21 +35,21 @@ def get_breadcrumbs(url, request=None):
breadcrumbs_list.insert(0, (name, insert_url)) breadcrumbs_list.insert(0, (name, insert_url))
seen.append(view) seen.append(view)
if url == '': if url == "":
# All done # All done
return breadcrumbs_list return breadcrumbs_list
elif url.endswith('/'): elif url.endswith("/"):
# Drop trailing slash off the end and continue to try to # Drop trailing slash off the end and continue to try to
# resolve more breadcrumbs # resolve more breadcrumbs
url = url.rstrip('/') url = url.rstrip("/")
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to # Drop trailing non-slash off the end and continue to try to
# resolve more breadcrumbs # resolve more breadcrumbs
url = url[:url.rfind('/') + 1] url = url[: url.rfind("/") + 1]
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/') prefix = get_script_prefix().rstrip("/")
url = url[len(prefix):] url = url[len(prefix) :]
return breadcrumbs_recursive(url, [], prefix, []) return breadcrumbs_recursive(url, [], prefix, [])

View File

@ -21,6 +21,7 @@ class JSONEncoder(json.JSONEncoder):
JSONEncoder subclass that knows how to encode date/time/timedelta, JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, generators and other basic python objects. decimal types, generators and other basic python objects.
""" """
def default(self, obj): def default(self, obj):
# For Date Time string spec, see ECMA 262 # For Date Time string spec, see ECMA 262
# https://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 # https://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
@ -28,8 +29,8 @@ class JSONEncoder(json.JSONEncoder):
return force_text(obj) return force_text(obj)
elif isinstance(obj, datetime.datetime): elif isinstance(obj, datetime.datetime):
representation = obj.isoformat() representation = obj.isoformat()
if representation.endswith('+00:00'): if representation.endswith("+00:00"):
representation = representation[:-6] + 'Z' representation = representation[:-6] + "Z"
return representation return representation
elif isinstance(obj, datetime.date): elif isinstance(obj, datetime.date):
return obj.isoformat() return obj.isoformat()
@ -49,20 +50,22 @@ class JSONEncoder(json.JSONEncoder):
return tuple(obj) return tuple(obj)
elif isinstance(obj, bytes): elif isinstance(obj, bytes):
# Best-effort for binary blobs. See #4187. # Best-effort for binary blobs. See #4187.
return obj.decode('utf-8') return obj.decode("utf-8")
elif hasattr(obj, 'tolist'): elif hasattr(obj, "tolist"):
# Numpy arrays and array scalars. # Numpy arrays and array scalars.
return obj.tolist() return obj.tolist()
elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)): elif (coreapi is not None) and isinstance(
obj, (coreapi.Document, coreapi.Error)
):
raise RuntimeError( raise RuntimeError(
'Cannot return a coreapi object from a JSON view. ' "Cannot return a coreapi object from a JSON view. "
'You should be using a schema renderer instead for this view.' "You should be using a schema renderer instead for this view."
) )
elif hasattr(obj, '__getitem__'): elif hasattr(obj, "__getitem__"):
try: try:
return dict(obj) return dict(obj)
except Exception: except Exception:
pass pass
elif hasattr(obj, '__iter__'): elif hasattr(obj, "__iter__"):
return tuple(item for item in obj) return tuple(item for item in obj)
return super(JSONEncoder, self).default(obj) return super(JSONEncoder, self).default(obj)

View File

@ -11,8 +11,12 @@ from django.utils.text import capfirst
from rest_framework.compat import postgres_fields from rest_framework.compat import postgres_fields
from rest_framework.validators import UniqueValidator from rest_framework.validators import UniqueValidator
NUMERIC_FIELD_TYPES = ( NUMERIC_FIELD_TYPES = (
models.IntegerField, models.FloatField, models.DecimalField, models.DurationField, models.IntegerField,
models.FloatField,
models.DecimalField,
models.DurationField,
) )
@ -23,11 +27,12 @@ class ClassLookupDict(object):
hierarchy in method resolution order, and returns the first matching value hierarchy in method resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches. from the dictionary or raises a KeyError if nothing matches.
""" """
def __init__(self, mapping): def __init__(self, mapping):
self.mapping = mapping self.mapping = mapping
def __getitem__(self, key): def __getitem__(self, key):
if hasattr(key, '_proxy_class'): if hasattr(key, "_proxy_class"):
# Deal with proxy classes. Ie. BoundField behaves as if it # Deal with proxy classes. Ie. BoundField behaves as if it
# is a Field instance when using ClassLookupDict. # is a Field instance when using ClassLookupDict.
base_class = key._proxy_class base_class = key._proxy_class
@ -37,7 +42,7 @@ class ClassLookupDict(object):
for cls in inspect.getmro(base_class): for cls in inspect.getmro(base_class):
if cls in self.mapping: if cls in self.mapping:
return self.mapping[cls] return self.mapping[cls]
raise KeyError('Class %s not found in lookup.' % base_class.__name__) raise KeyError("Class %s not found in lookup." % base_class.__name__)
def __setitem__(self, key, value): def __setitem__(self, key, value):
self.mapping[key] = value self.mapping[key] = value
@ -48,7 +53,7 @@ def needs_label(model_field, field_name):
Returns `True` if the label based on the model's verbose name Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name. is not equal to the default label it would have based on it's field name.
""" """
default_label = field_name.replace('_', ' ').capitalize() default_label = field_name.replace("_", " ").capitalize()
return capfirst(model_field.verbose_name) != default_label return capfirst(model_field.verbose_name) != default_label
@ -57,9 +62,9 @@ def get_detail_view_name(model):
Given a model class, return the view name to use for URL relationships Given a model class, return the view name to use for URL relationships
that refer to instances of the model. that refer to instances of the model.
""" """
return '%(model_name)s-detail' % { return "%(model_name)s-detail" % {
'app_label': model._meta.app_label, "app_label": model._meta.app_label,
'model_name': model._meta.object_name.lower() "model_name": model._meta.object_name.lower(),
} }
@ -72,84 +77,98 @@ def get_field_kwargs(field_name, model_field):
# The following will only be used by ModelField classes. # The following will only be used by ModelField classes.
# Gets removed for everything else. # Gets removed for everything else.
kwargs['model_field'] = model_field kwargs["model_field"] = model_field
if model_field.verbose_name and needs_label(model_field, field_name): if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name) kwargs["label"] = capfirst(model_field.verbose_name)
if model_field.help_text: if model_field.help_text:
kwargs['help_text'] = model_field.help_text kwargs["help_text"] = model_field.help_text
max_digits = getattr(model_field, 'max_digits', None) max_digits = getattr(model_field, "max_digits", None)
if max_digits is not None: if max_digits is not None:
kwargs['max_digits'] = max_digits kwargs["max_digits"] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None) decimal_places = getattr(model_field, "decimal_places", None)
if decimal_places is not None: if decimal_places is not None:
kwargs['decimal_places'] = decimal_places kwargs["decimal_places"] = decimal_places
if isinstance(model_field, models.SlugField): if isinstance(model_field, models.SlugField):
kwargs['allow_unicode'] = model_field.allow_unicode kwargs["allow_unicode"] = model_field.allow_unicode
if isinstance(model_field, models.TextField) or (postgres_fields and isinstance(model_field, postgres_fields.JSONField)): if isinstance(model_field, models.TextField) or (
kwargs['style'] = {'base_template': 'textarea.html'} postgres_fields and isinstance(model_field, postgres_fields.JSONField)
):
kwargs["style"] = {"base_template": "textarea.html"}
if isinstance(model_field, models.AutoField) or not model_field.editable: if isinstance(model_field, models.AutoField) or not model_field.editable:
# If this field is read-only, then return early. # If this field is read-only, then return early.
# Further keyword arguments are not valid. # Further keyword arguments are not valid.
kwargs['read_only'] = True kwargs["read_only"] = True
return kwargs return kwargs
if model_field.has_default() or model_field.blank or model_field.null: if model_field.has_default() or model_field.blank or model_field.null:
kwargs['required'] = False kwargs["required"] = False
if model_field.null and not isinstance(model_field, models.NullBooleanField): if model_field.null and not isinstance(model_field, models.NullBooleanField):
kwargs['allow_null'] = True kwargs["allow_null"] = True
if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))): if model_field.blank and (
kwargs['allow_blank'] = True isinstance(model_field, (models.CharField, models.TextField))
):
kwargs["allow_blank"] = True
if isinstance(model_field, models.FilePathField): if isinstance(model_field, models.FilePathField):
kwargs['path'] = model_field.path kwargs["path"] = model_field.path
if model_field.match is not None: if model_field.match is not None:
kwargs['match'] = model_field.match kwargs["match"] = model_field.match
if model_field.recursive is not False: if model_field.recursive is not False:
kwargs['recursive'] = model_field.recursive kwargs["recursive"] = model_field.recursive
if model_field.allow_files is not True: if model_field.allow_files is not True:
kwargs['allow_files'] = model_field.allow_files kwargs["allow_files"] = model_field.allow_files
if model_field.allow_folders is not False: if model_field.allow_folders is not False:
kwargs['allow_folders'] = model_field.allow_folders kwargs["allow_folders"] = model_field.allow_folders
if model_field.choices: if model_field.choices:
kwargs['choices'] = model_field.choices kwargs["choices"] = model_field.choices
else: else:
# Ensure that max_value is passed explicitly as a keyword arg, # Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator. # rather than as a validator.
max_value = next(( max_value = next(
validator.limit_value for validator in validator_kwarg (
if isinstance(validator, validators.MaxValueValidator) validator.limit_value
), None) for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
),
None,
)
if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['max_value'] = max_value kwargs["max_value"] = max_value
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator) if not isinstance(validator, validators.MaxValueValidator)
] ]
# Ensure that min_value is passed explicitly as a keyword arg, # Ensure that min_value is passed explicitly as a keyword arg,
# rather than as a validator. # rather than as a validator.
min_value = next(( min_value = next(
validator.limit_value for validator in validator_kwarg (
if isinstance(validator, validators.MinValueValidator) validator.limit_value
), None) for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
),
None,
)
if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['min_value'] = min_value kwargs["min_value"] = min_value
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator) if not isinstance(validator, validators.MinValueValidator)
] ]
@ -157,7 +176,8 @@ def get_field_kwargs(field_name, model_field):
# as it is explicitly added in. # as it is explicitly added in.
if isinstance(model_field, models.URLField): if isinstance(model_field, models.URLField):
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator) if not isinstance(validator, validators.URLValidator)
] ]
@ -165,67 +185,79 @@ def get_field_kwargs(field_name, model_field):
# as it is explicitly added in. # as it is explicitly added in.
if isinstance(model_field, models.EmailField): if isinstance(model_field, models.EmailField):
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if validator is not validators.validate_email if validator is not validators.validate_email
] ]
# SlugField do not need to include the 'validate_slug' argument, # SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField): if isinstance(model_field, models.SlugField):
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if validator is not validators.validate_slug if validator is not validators.validate_slug
] ]
# IPAddressField do not need to include the 'validate_ipv46_address' argument, # IPAddressField do not need to include the 'validate_ipv46_address' argument,
if isinstance(model_field, models.GenericIPAddressField): if isinstance(model_field, models.GenericIPAddressField):
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if validator is not validators.validate_ipv46_address if validator is not validators.validate_ipv46_address
] ]
# Our decimal validation is handled in the field code, not validator code. # Our decimal validation is handled in the field code, not validator code.
if isinstance(model_field, models.DecimalField): if isinstance(model_field, models.DecimalField):
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if not isinstance(validator, validators.DecimalValidator) if not isinstance(validator, validators.DecimalValidator)
] ]
# Ensure that max_length is passed explicitly as a keyword arg, # Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator. # rather than as a validator.
max_length = getattr(model_field, 'max_length', None) max_length = getattr(model_field, "max_length", None)
if max_length is not None and (isinstance(model_field, (models.CharField, models.TextField, models.FileField))): if max_length is not None and (
kwargs['max_length'] = max_length isinstance(model_field, (models.CharField, models.TextField, models.FileField))
):
kwargs["max_length"] = max_length
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator) if not isinstance(validator, validators.MaxLengthValidator)
] ]
# Ensure that min_length is passed explicitly as a keyword arg, # Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator. # rather than as a validator.
min_length = next(( min_length = next(
validator.limit_value for validator in validator_kwarg (
if isinstance(validator, validators.MinLengthValidator) validator.limit_value
), None) for validator in validator_kwarg
if isinstance(validator, validators.MinLengthValidator)
),
None,
)
if min_length is not None and isinstance(model_field, models.CharField): if min_length is not None and isinstance(model_field, models.CharField):
kwargs['min_length'] = min_length kwargs["min_length"] = min_length
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator
for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator) if not isinstance(validator, validators.MinLengthValidator)
] ]
if getattr(model_field, 'unique', False): if getattr(model_field, "unique", False):
unique_error_message = model_field.error_messages.get('unique', None) unique_error_message = model_field.error_messages.get("unique", None)
if unique_error_message: if unique_error_message:
unique_error_message = unique_error_message % { unique_error_message = unique_error_message % {
'model_name': model_field.model._meta.verbose_name, "model_name": model_field.model._meta.verbose_name,
'field_label': model_field.verbose_name "field_label": model_field.verbose_name,
} }
validator = UniqueValidator( validator = UniqueValidator(
queryset=model_field.model._default_manager, queryset=model_field.model._default_manager, message=unique_error_message
message=unique_error_message) )
validator_kwarg.append(validator) validator_kwarg.append(validator)
if validator_kwarg: if validator_kwarg:
kwargs['validators'] = validator_kwarg kwargs["validators"] = validator_kwarg
return kwargs return kwargs
@ -234,65 +266,65 @@ def get_relation_kwargs(field_name, relation_info):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
""" """
model_field, related_model, to_many, to_field, has_through_model, reverse = relation_info model_field, related_model, to_many, to_field, has_through_model, reverse = (
relation_info
)
kwargs = { kwargs = {
'queryset': related_model._default_manager, "queryset": related_model._default_manager,
'view_name': get_detail_view_name(related_model) "view_name": get_detail_view_name(related_model),
} }
if to_many: if to_many:
kwargs['many'] = True kwargs["many"] = True
if to_field: if to_field:
kwargs['to_field'] = to_field kwargs["to_field"] = to_field
limit_choices_to = model_field and model_field.get_limit_choices_to() limit_choices_to = model_field and model_field.get_limit_choices_to()
if limit_choices_to: if limit_choices_to:
if not isinstance(limit_choices_to, models.Q): if not isinstance(limit_choices_to, models.Q):
limit_choices_to = models.Q(**limit_choices_to) limit_choices_to = models.Q(**limit_choices_to)
kwargs['queryset'] = kwargs['queryset'].filter(limit_choices_to) kwargs["queryset"] = kwargs["queryset"].filter(limit_choices_to)
if has_through_model: if has_through_model:
kwargs['read_only'] = True kwargs["read_only"] = True
kwargs.pop('queryset', None) kwargs.pop("queryset", None)
if model_field: if model_field:
if model_field.verbose_name and needs_label(model_field, field_name): if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name) kwargs["label"] = capfirst(model_field.verbose_name)
help_text = model_field.help_text help_text = model_field.help_text
if help_text: if help_text:
kwargs['help_text'] = help_text kwargs["help_text"] = help_text
if not model_field.editable: if not model_field.editable:
kwargs['read_only'] = True kwargs["read_only"] = True
kwargs.pop('queryset', None) kwargs.pop("queryset", None)
if kwargs.get('read_only', False): if kwargs.get("read_only", False):
# If this field is read-only, then return early. # If this field is read-only, then return early.
# No further keyword arguments are valid. # No further keyword arguments are valid.
return kwargs return kwargs
if model_field.has_default() or model_field.blank or model_field.null: if model_field.has_default() or model_field.blank or model_field.null:
kwargs['required'] = False kwargs["required"] = False
if model_field.null: if model_field.null:
kwargs['allow_null'] = True kwargs["allow_null"] = True
if model_field.validators: if model_field.validators:
kwargs['validators'] = model_field.validators kwargs["validators"] = model_field.validators
if getattr(model_field, 'unique', False): if getattr(model_field, "unique", False):
validator = UniqueValidator(queryset=model_field.model._default_manager) validator = UniqueValidator(queryset=model_field.model._default_manager)
kwargs['validators'] = kwargs.get('validators', []) + [validator] kwargs["validators"] = kwargs.get("validators", []) + [validator]
if to_many and not model_field.blank: if to_many and not model_field.blank:
kwargs['allow_empty'] = False kwargs["allow_empty"] = False
return kwargs return kwargs
def get_nested_relation_kwargs(relation_info): def get_nested_relation_kwargs(relation_info):
kwargs = {'read_only': True} kwargs = {"read_only": True}
if relation_info.to_many: if relation_info.to_many:
kwargs['many'] = True kwargs["many"] = True
return kwargs return kwargs
def get_url_kwargs(model_field): def get_url_kwargs(model_field):
return { return {"view_name": get_detail_view_name(model_field)}
'view_name': get_detail_view_name(model_field)
}

View File

@ -18,7 +18,7 @@ def remove_trailing_string(content, trailing):
Used when generating names from view classes. Used when generating names from view classes.
""" """
if content.endswith(trailing) and content != trailing: if content.endswith(trailing) and content != trailing:
return content[:-len(trailing)] return content[: -len(trailing)]
return content return content
@ -36,14 +36,14 @@ def dedent(content):
# unindent the content if needed # unindent the content if needed
if lines: if lines:
whitespace_counts = min([len(line) - len(line.lstrip(' ')) for line in lines]) whitespace_counts = min([len(line) - len(line.lstrip(" ")) for line in lines])
tab_counts = min([len(line) - len(line.lstrip('\t')) for line in lines]) tab_counts = min([len(line) - len(line.lstrip("\t")) for line in lines])
if whitespace_counts: if whitespace_counts:
whitespace_pattern = '^' + (' ' * whitespace_counts) whitespace_pattern = "^" + (" " * whitespace_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content)
elif tab_counts: elif tab_counts:
whitespace_pattern = '^' + ('\t' * tab_counts) whitespace_pattern = "^" + ("\t" * tab_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content)
return content.strip() return content.strip()
@ -52,9 +52,9 @@ def camelcase_to_spaces(content):
Translate 'CamelCaseNames' to 'Camel Case Names'. Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes. Used when generating names from view classes.
""" """
camelcase_boundary = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' camelcase_boundary = "(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))"
content = re.sub(camelcase_boundary, ' \\1', content).strip() content = re.sub(camelcase_boundary, " \\1", content).strip()
return ' '.join(content.split('_')).title() return " ".join(content.split("_")).title()
def markup_description(description): def markup_description(description):
@ -64,6 +64,6 @@ def markup_description(description):
if apply_markdown: if apply_markdown:
description = apply_markdown(description) description = apply_markdown(description)
else: else:
description = escape(description).replace('\n', '<br />') description = escape(description).replace("\n", "<br />")
description = '<p>' + description + '</p>' description = "<p>" + description + "</p>"
return mark_safe(description) return mark_safe(description)

View File

@ -9,10 +9,10 @@ from django.utils.datastructures import MultiValueDict
def is_html_input(dictionary): def is_html_input(dictionary):
# MultiDict type datastructures are used to represent HTML form input, # MultiDict type datastructures are used to represent HTML form input,
# which may have more than one value for each key. # which may have more than one value for each key.
return hasattr(dictionary, 'getlist') return hasattr(dictionary, "getlist")
def parse_html_list(dictionary, prefix='', default=None): def parse_html_list(dictionary, prefix="", default=None):
""" """
Used to support list values in HTML forms. Used to support list values in HTML forms.
Supports lists of primitives and/or dictionaries. Supports lists of primitives and/or dictionaries.
@ -48,7 +48,7 @@ def parse_html_list(dictionary, prefix='', default=None):
:returns a list of objects, or the value specified in ``default`` if the list is empty :returns a list of objects, or the value specified in ``default`` if the list is empty
""" """
ret = {} ret = {}
regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) regex = re.compile(r"^%s\[([0-9]+)\](.*)$" % re.escape(prefix))
for field, value in dictionary.items(): for field, value in dictionary.items():
match = regex.match(field) match = regex.match(field)
if not match: if not match:
@ -66,7 +66,7 @@ def parse_html_list(dictionary, prefix='', default=None):
return [ret[item] for item in sorted(ret)] if ret else default return [ret[item] for item in sorted(ret)] if ret else default
def parse_html_dict(dictionary, prefix=''): def parse_html_dict(dictionary, prefix=""):
""" """
Used to support dictionary values in HTML forms. Used to support dictionary values in HTML forms.
@ -83,7 +83,7 @@ def parse_html_dict(dictionary, prefix=''):
} }
""" """
ret = MultiValueDict() ret = MultiValueDict()
regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix)) regex = re.compile(r"^%s\.(.+)$" % re.escape(prefix))
for field in dictionary: for field in dictionary:
match = regex.match(field) match = regex.match(field)
if not match: if not match:

View File

@ -5,20 +5,19 @@ from rest_framework import ISO_8601
def datetime_formats(formats): def datetime_formats(formats):
format = ', '.join(formats).replace( format = ", ".join(formats).replace(
ISO_8601, ISO_8601, "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
) )
return humanize_strptime(format) return humanize_strptime(format)
def date_formats(formats): def date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DD') format = ", ".join(formats).replace(ISO_8601, "YYYY-MM-DD")
return humanize_strptime(format) return humanize_strptime(format)
def time_formats(formats): def time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') format = ", ".join(formats).replace(ISO_8601, "hh:mm[:ss[.uuuuuu]]")
return humanize_strptime(format) return humanize_strptime(format)
@ -40,7 +39,7 @@ def humanize_strptime(format_string):
"%a": "[Mon-Sun]", "%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]", "%A": "[Monday-Sunday]",
"%p": "[AM|PM]", "%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]" "%z": "[+HHMM|-HHMM]",
} }
for key, val in mapping.items(): for key, val in mapping.items():
format_string = format_string.replace(key, val) format_string = format_string.replace(key, val)

View File

@ -13,28 +13,28 @@ import json # noqa
def strict_constant(o): def strict_constant(o):
raise ValueError('Out of range float values are not JSON compliant: ' + repr(o)) raise ValueError("Out of range float values are not JSON compliant: " + repr(o))
@functools.wraps(json.dump) @functools.wraps(json.dump)
def dump(*args, **kwargs): def dump(*args, **kwargs):
kwargs.setdefault('allow_nan', False) kwargs.setdefault("allow_nan", False)
return json.dump(*args, **kwargs) return json.dump(*args, **kwargs)
@functools.wraps(json.dumps) @functools.wraps(json.dumps)
def dumps(*args, **kwargs): def dumps(*args, **kwargs):
kwargs.setdefault('allow_nan', False) kwargs.setdefault("allow_nan", False)
return json.dumps(*args, **kwargs) return json.dumps(*args, **kwargs)
@functools.wraps(json.load) @functools.wraps(json.load)
def load(*args, **kwargs): def load(*args, **kwargs):
kwargs.setdefault('parse_constant', strict_constant) kwargs.setdefault("parse_constant", strict_constant)
return json.load(*args, **kwargs) return json.load(*args, **kwargs)
@functools.wraps(json.loads) @functools.wraps(json.loads)
def loads(*args, **kwargs): def loads(*args, **kwargs):
kwargs.setdefault('parse_constant', strict_constant) kwargs.setdefault("parse_constant", strict_constant)
return json.loads(*args, **kwargs) return json.loads(*args, **kwargs)

View File

@ -49,20 +49,30 @@ def order_by_precedence(media_type_lst):
@python_2_unicode_compatible @python_2_unicode_compatible
class _MediaType(object): class _MediaType(object):
def __init__(self, media_type_str): def __init__(self, media_type_str):
self.orig = '' if (media_type_str is None) else media_type_str self.orig = "" if (media_type_str is None) else media_type_str
self.full_type, self.params = parse_header(self.orig.encode(HTTP_HEADER_ENCODING)) self.full_type, self.params = parse_header(
self.main_type, sep, self.sub_type = self.full_type.partition('/') self.orig.encode(HTTP_HEADER_ENCODING)
)
self.main_type, sep, self.sub_type = self.full_type.partition("/")
def match(self, other): def match(self, other):
"""Return true if this MediaType satisfies the given MediaType.""" """Return true if this MediaType satisfies the given MediaType."""
for key in self.params: for key in self.params:
if key != 'q' and other.params.get(key, None) != self.params.get(key, None): if key != "q" and other.params.get(key, None) != self.params.get(key, None):
return False return False
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type: if (
self.sub_type != "*"
and other.sub_type != "*"
and other.sub_type != self.sub_type
):
return False return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type: if (
self.main_type != "*"
and other.main_type != "*"
and other.main_type != self.main_type
):
return False return False
return True return True
@ -72,16 +82,16 @@ class _MediaType(object):
""" """
Return a precedence level from 0-3 for the media type given how specific it is. Return a precedence level from 0-3 for the media type given how specific it is.
""" """
if self.main_type == '*': if self.main_type == "*":
return 0 return 0
elif self.sub_type == '*': elif self.sub_type == "*":
return 1 return 1
elif not self.params or list(self.params) == ['q']: elif not self.params or list(self.params) == ["q"]:
return 2 return 2
return 3 return 3
def __str__(self): def __str__(self):
ret = "%s/%s" % (self.main_type, self.sub_type) ret = "%s/%s" % (self.main_type, self.sub_type)
for key, val in self.params.items(): for key, val in self.params.items():
ret += "; %s=%s" % (key, val.decode('ascii')) ret += "; %s=%s" % (key, val.decode("ascii"))
return ret return ret

View File

@ -7,23 +7,30 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance.
""" """
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance
'fields', # Dict of field name -> model field instance
'forward_relations', # Dict of field name -> RelationInfo
'reverse_relations', # Dict of field name -> RelationInfo
'fields_and_pk', # Shortcut for 'pk' + 'fields'
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
])
RelationInfo = namedtuple('RelationInfo', [ FieldInfo = namedtuple(
'model_field', "FieldResult",
'related_model', [
'to_many', "pk", # Model field instance
'to_field', "fields", # Dict of field name -> model field instance
'has_through_model', "forward_relations", # Dict of field name -> RelationInfo
'reverse' "reverse_relations", # Dict of field name -> RelationInfo
]) "fields_and_pk", # Shortcut for 'pk' + 'fields'
"relations", # Shortcut for 'forward_relations' + 'reverse_relations'
],
)
RelationInfo = namedtuple(
"RelationInfo",
[
"model_field",
"related_model",
"to_many",
"to_field",
"has_through_model",
"reverse",
],
)
def get_field_info(model): def get_field_info(model):
@ -41,8 +48,9 @@ def get_field_info(model):
fields_and_pk = _merge_fields_and_pk(pk, fields) fields_and_pk = _merge_fields_and_pk(pk, fields)
relationships = _merge_relationships(forward_relations, reverse_relations) relationships = _merge_relationships(forward_relations, reverse_relations)
return FieldInfo(pk, fields, forward_relations, reverse_relations, return FieldInfo(
fields_and_pk, relationships) pk, fields, forward_relations, reverse_relations, fields_and_pk, relationships
)
def _get_pk(opts): def _get_pk(opts):
@ -59,14 +67,16 @@ def _get_pk(opts):
def _get_fields(opts): def _get_fields(opts):
fields = OrderedDict() fields = OrderedDict()
for field in [field for field in opts.fields if field.serialize and not field.remote_field]: for field in [
field for field in opts.fields if field.serialize and not field.remote_field
]:
fields[field.name] = field fields[field.name] = field
return fields return fields
def _get_to_field(field): def _get_to_field(field):
return getattr(field, 'to_fields', None) and field.to_fields[0] return getattr(field, "to_fields", None) and field.to_fields[0]
def _get_forward_relationships(opts): def _get_forward_relationships(opts):
@ -74,14 +84,16 @@ def _get_forward_relationships(opts):
Returns an `OrderedDict` of field names to `RelationInfo`. Returns an `OrderedDict` of field names to `RelationInfo`.
""" """
forward_relations = OrderedDict() forward_relations = OrderedDict()
for field in [field for field in opts.fields if field.serialize and field.remote_field]: for field in [
field for field in opts.fields if field.serialize and field.remote_field
]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
related_model=field.remote_field.model, related_model=field.remote_field.model,
to_many=False, to_many=False,
to_field=_get_to_field(field), to_field=_get_to_field(field),
has_through_model=False, has_through_model=False,
reverse=False reverse=False,
) )
# Deal with forward many-to-many relationships. # Deal with forward many-to-many relationships.
@ -92,10 +104,8 @@ def _get_forward_relationships(opts):
to_many=True, to_many=True,
# manytomany do not have to_fields # manytomany do not have to_fields
to_field=None, to_field=None,
has_through_model=( has_through_model=(not field.remote_field.through._meta.auto_created),
not field.remote_field.through._meta.auto_created reverse=False,
),
reverse=False
) )
return forward_relations return forward_relations
@ -115,11 +125,13 @@ def _get_reverse_relationships(opts):
to_many=relation.field.remote_field.multiple, to_many=relation.field.remote_field.multiple,
to_field=_get_to_field(relation.field), to_field=_get_to_field(relation.field),
has_through_model=False, has_through_model=False,
reverse=True reverse=True,
) )
# Deal with reverse many-to-many relationships. # Deal with reverse many-to-many relationships.
all_related_many_to_many_objects = [r for r in opts.related_objects if r.field.many_to_many] all_related_many_to_many_objects = [
r for r in opts.related_objects if r.field.many_to_many
]
for relation in all_related_many_to_many_objects: for relation in all_related_many_to_many_objects:
accessor_name = relation.get_accessor_name() accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo( reverse_relations[accessor_name] = RelationInfo(
@ -129,10 +141,10 @@ def _get_reverse_relationships(opts):
# manytomany do not have to_fields # manytomany do not have to_fields
to_field=None, to_field=None,
has_through_model=( has_through_model=(
(getattr(relation.field.remote_field, 'through', None) is not None) and (getattr(relation.field.remote_field, "through", None) is not None)
not relation.field.remote_field.through._meta.auto_created and not relation.field.remote_field.through._meta.auto_created
), ),
reverse=True reverse=True,
) )
return reverse_relations return reverse_relations
@ -140,7 +152,7 @@ def _get_reverse_relationships(opts):
def _merge_fields_and_pk(pk, fields): def _merge_fields_and_pk(pk, fields):
fields_and_pk = OrderedDict() fields_and_pk = OrderedDict()
fields_and_pk['pk'] = pk fields_and_pk["pk"] = pk
fields_and_pk[pk.name] = pk fields_and_pk[pk.name] = pk
fields_and_pk.update(fields) fields_and_pk.update(fields)
@ -149,8 +161,7 @@ def _merge_fields_and_pk(pk, fields):
def _merge_relationships(forward_relations, reverse_relations): def _merge_relationships(forward_relations, reverse_relations):
return OrderedDict( return OrderedDict(
list(forward_relations.items()) + list(forward_relations.items()) + list(reverse_relations.items())
list(reverse_relations.items())
) )
@ -158,4 +169,8 @@ def is_abstract_model(model):
""" """
Given a model class, returns a boolean True if it is abstract and False if it is not. Given a model class, returns a boolean True if it is abstract and False if it is not.
""" """
return hasattr(model, '_meta') and hasattr(model._meta, 'abstract') and model._meta.abstract return (
hasattr(model, "_meta")
and hasattr(model._meta, "abstract")
and model._meta.abstract
)

View File

@ -16,14 +16,10 @@ from rest_framework.compat import unicode_repr
def manager_repr(value): def manager_repr(value):
model = value.model model = value.model
opts = model._meta opts = model._meta
names_and_managers = [ names_and_managers = [(manager.name, manager) for manager in opts.managers]
(manager.name, manager)
for manager
in opts.managers
]
for manager_name, manager_instance in names_and_managers: for manager_name, manager_instance in names_and_managers:
if manager_instance == value: if manager_instance == value:
return '%s.%s.all()' % (model._meta.object_name, manager_name) return "%s.%s.all()" % (model._meta.object_name, manager_name)
return repr(value) return repr(value)
@ -45,7 +41,7 @@ def smart_repr(value):
# <django.core.validators.RegexValidator object at 0x1047af050> # <django.core.validators.RegexValidator object at 0x1047af050>
# Should be presented as # Should be presented as
# <django.core.validators.RegexValidator object> # <django.core.validators.RegexValidator object>
value = re.sub(' at 0x[0-9A-Fa-f]{4,32}>', '>', value) value = re.sub(" at 0x[0-9A-Fa-f]{4,32}>", ">", value)
return value return value
@ -54,16 +50,15 @@ def field_repr(field, force_many=False):
kwargs = field._kwargs kwargs = field._kwargs
if force_many: if force_many:
kwargs = kwargs.copy() kwargs = kwargs.copy()
kwargs['many'] = True kwargs["many"] = True
kwargs.pop('child', None) kwargs.pop("child", None)
arg_string = ', '.join([smart_repr(val) for val in field._args]) arg_string = ", ".join([smart_repr(val) for val in field._args])
kwarg_string = ', '.join([ kwarg_string = ", ".join(
'%s=%s' % (key, smart_repr(val)) ["%s=%s" % (key, smart_repr(val)) for key, val in sorted(kwargs.items())]
for key, val in sorted(kwargs.items()) )
])
if arg_string and kwarg_string: if arg_string and kwarg_string:
arg_string += ', ' arg_string += ", "
if force_many: if force_many:
class_name = force_many.__class__.__name__ class_name = force_many.__class__.__name__
@ -74,8 +69,8 @@ def field_repr(field, force_many=False):
def serializer_repr(serializer, indent, force_many=None): def serializer_repr(serializer, indent, force_many=None):
ret = field_repr(serializer, force_many) + ':' ret = field_repr(serializer, force_many) + ":"
indent_str = ' ' * indent indent_str = " " * indent
if force_many: if force_many:
fields = force_many.fields fields = force_many.fields
@ -83,25 +78,27 @@ def serializer_repr(serializer, indent, force_many=None):
fields = serializer.fields fields = serializer.fields
for field_name, field in fields.items(): for field_name, field in fields.items():
ret += '\n' + indent_str + field_name + ' = ' ret += "\n" + indent_str + field_name + " = "
if hasattr(field, 'fields'): if hasattr(field, "fields"):
ret += serializer_repr(field, indent + 1) ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'): elif hasattr(field, "child"):
ret += list_repr(field, indent + 1) ret += list_repr(field, indent + 1)
elif hasattr(field, 'child_relation'): elif hasattr(field, "child_relation"):
ret += field_repr(field.child_relation, force_many=field.child_relation) ret += field_repr(field.child_relation, force_many=field.child_relation)
else: else:
ret += field_repr(field) ret += field_repr(field)
if serializer.validators: if serializer.validators:
ret += '\n' + indent_str + 'class Meta:' ret += "\n" + indent_str + "class Meta:"
ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators) ret += (
"\n" + indent_str + " validators = " + smart_repr(serializer.validators)
)
return ret return ret
def list_repr(serializer, indent): def list_repr(serializer, indent):
child = serializer.child child = serializer.child
if hasattr(child, 'fields'): if hasattr(child, "fields"):
return serializer_repr(serializer, indent, force_many=child) return serializer_repr(serializer, indent, force_many=child)
return field_repr(serializer) return field_repr(serializer)

View File

@ -16,7 +16,7 @@ class ReturnDict(OrderedDict):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.serializer = kwargs.pop('serializer') self.serializer = kwargs.pop("serializer")
super(ReturnDict, self).__init__(*args, **kwargs) super(ReturnDict, self).__init__(*args, **kwargs)
def copy(self): def copy(self):
@ -39,7 +39,7 @@ class ReturnList(list):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.serializer = kwargs.pop('serializer') self.serializer = kwargs.pop("serializer")
super(ReturnList, self).__init__(*args, **kwargs) super(ReturnList, self).__init__(*args, **kwargs)
def __repr__(self): def __repr__(self):
@ -58,7 +58,7 @@ class BoundField(object):
providing an API similar to Django forms and form fields. providing an API similar to Django forms and form fields.
""" """
def __init__(self, field, value, errors, prefix=''): def __init__(self, field, value, errors, prefix=""):
self._field = field self._field = field
self._prefix = prefix self._prefix = prefix
self.value = value self.value = value
@ -73,12 +73,13 @@ class BoundField(object):
return self._field.__class__ return self._field.__class__
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s value=%s errors=%s>' % ( return unicode_to_repr(
self.__class__.__name__, self.value, self.errors "<%s value=%s errors=%s>"
)) % (self.__class__.__name__, self.value, self.errors)
)
def as_form_field(self): def as_form_field(self):
value = '' if (self.value is None or self.value is False) else self.value value = "" if (self.value is None or self.value is False) else self.value
return self.__class__(self._field, value, self.errors, self._prefix) return self.__class__(self._field, value, self.errors, self._prefix)
@ -87,7 +88,7 @@ class JSONBoundField(BoundField):
value = self.value value = self.value
# When HTML form input is used and the input is not valid # When HTML form input is used and the input is not valid
# value will be a JSONString, rather than a JSON primitive. # value will be a JSONString, rather than a JSON primitive.
if not getattr(value, 'is_json_string', False): if not getattr(value, "is_json_string", False):
try: try:
value = json.dumps(self.value, sort_keys=True, indent=4) value = json.dumps(self.value, sort_keys=True, indent=4)
except (TypeError, ValueError): except (TypeError, ValueError):
@ -102,8 +103,8 @@ class NestedBoundField(BoundField):
`BoundField` that is used for serializer fields. `BoundField` that is used for serializer fields.
""" """
def __init__(self, field, value, errors, prefix=''): def __init__(self, field, value, errors, prefix=""):
if value is None or value is '': if value is None or value is "":
value = {} value = {}
super(NestedBoundField, self).__init__(field, value, errors, prefix) super(NestedBoundField, self).__init__(field, value, errors, prefix)
@ -115,9 +116,9 @@ class NestedBoundField(BoundField):
field = self.fields[key] field = self.fields[key]
value = self.value.get(key) if self.value else None value = self.value.get(key) if self.value else None
error = self.errors.get(key) if isinstance(self.errors, dict) else None error = self.errors.get(key) if isinstance(self.errors, dict) else None
if hasattr(field, 'fields'): if hasattr(field, "fields"):
return NestedBoundField(field, value, error, prefix=self.name + '.') return NestedBoundField(field, value, error, prefix=self.name + ".")
return BoundField(field, value, error, prefix=self.name + '.') return BoundField(field, value, error, prefix=self.name + ".")
def as_form_field(self): def as_form_field(self):
values = {} values = {}
@ -125,7 +126,9 @@ class NestedBoundField(BoundField):
if isinstance(value, (list, dict)): if isinstance(value, (list, dict)):
values[key] = value values[key] = value
else: else:
values[key] = '' if (value is None or value is False) else force_text(value) values[key] = (
"" if (value is None or value is False) else force_text(value)
)
return self.__class__(self._field, values, self.errors, self._prefix) return self.__class__(self._field, values, self.errors, self._prefix)

View File

@ -39,9 +39,10 @@ class UniqueValidator(object):
Should be applied to an individual field on the serializer. Should be applied to an individual field on the serializer.
""" """
message = _('This field must be unique.')
def __init__(self, queryset, message=None, lookup='exact'): message = _("This field must be unique.")
def __init__(self, queryset, message=None, lookup="exact"):
self.queryset = queryset self.queryset = queryset
self.serializer_field = None self.serializer_field = None
self.message = message or self.message self.message = message or self.message
@ -56,13 +57,13 @@ class UniqueValidator(object):
# same as the serializer field name if `source=<>` is set. # same as the serializer field name if `source=<>` is set.
self.field_name = serializer_field.source_attrs[-1] self.field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation. # Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer_field.parent, 'instance', None) self.instance = getattr(serializer_field.parent, "instance", None)
def filter_queryset(self, value, queryset): def filter_queryset(self, value, queryset):
""" """
Filter the queryset to all instances matching the given attribute. Filter the queryset to all instances matching the given attribute.
""" """
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value} filter_kwargs = {"%s__%s" % (self.field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, queryset): def exclude_current_instance(self, queryset):
@ -79,13 +80,12 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset) queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset) queryset = self.exclude_current_instance(queryset)
if qs_exists(queryset): if qs_exists(queryset):
raise ValidationError(self.message, code='unique') raise ValidationError(self.message, code="unique")
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % ( return unicode_to_repr(
self.__class__.__name__, "<%s(queryset=%s)>" % (self.__class__.__name__, smart_repr(self.queryset))
smart_repr(self.queryset) )
))
class UniqueTogetherValidator(object): class UniqueTogetherValidator(object):
@ -94,8 +94,9 @@ class UniqueTogetherValidator(object):
Should be applied to the serializer class, not to an individual field. Should be applied to the serializer class, not to an individual field.
""" """
message = _('The fields {field_names} must make a unique set.')
missing_message = _('This field is required.') message = _("The fields {field_names} must make a unique set.")
missing_message = _("This field is required.")
def __init__(self, queryset, fields, message=None): def __init__(self, queryset, fields, message=None):
self.queryset = queryset self.queryset = queryset
@ -109,7 +110,7 @@ class UniqueTogetherValidator(object):
prior to the validation call being made. prior to the validation call being made.
""" """
# Determine the existing instance, if this is an update operation. # Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None) self.instance = getattr(serializer, "instance", None)
def enforce_required_fields(self, attrs): def enforce_required_fields(self, attrs):
""" """
@ -125,7 +126,7 @@ class UniqueTogetherValidator(object):
if field_name not in attrs if field_name not in attrs
} }
if missing_items: if missing_items:
raise ValidationError(missing_items, code='required') raise ValidationError(missing_items, code="required")
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset):
""" """
@ -139,10 +140,7 @@ class UniqueTogetherValidator(object):
attrs[field_name] = getattr(self.instance, field_name) attrs[field_name] = getattr(self.instance, field_name)
# Determine the filter keyword arguments and filter the queryset. # Determine the filter keyword arguments and filter the queryset.
filter_kwargs = { filter_kwargs = {field_name: attrs[field_name] for field_name in self.fields}
field_name: attrs[field_name]
for field_name in self.fields
}
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, attrs, queryset): def exclude_current_instance(self, attrs, queryset):
@ -165,21 +163,24 @@ class UniqueTogetherValidator(object):
value for field, value in attrs.items() if field in self.fields value for field, value in attrs.items() if field in self.fields
] ]
if None not in checked_values and qs_exists(queryset): if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields) field_names = ", ".join(self.fields)
message = self.message.format(field_names=field_names) message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique') raise ValidationError(message, code="unique")
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( return unicode_to_repr(
self.__class__.__name__, "<%s(queryset=%s, fields=%s)>"
smart_repr(self.queryset), % (
smart_repr(self.fields) self.__class__.__name__,
)) smart_repr(self.queryset),
smart_repr(self.fields),
)
)
class BaseUniqueForValidator(object): class BaseUniqueForValidator(object):
message = None message = None
missing_message = _('This field is required.') missing_message = _("This field is required.")
def __init__(self, queryset, field, date_field, message=None): def __init__(self, queryset, field, date_field, message=None):
self.queryset = queryset self.queryset = queryset
@ -197,7 +198,7 @@ class BaseUniqueForValidator(object):
self.field_name = serializer.fields[self.field].source_attrs[-1] self.field_name = serializer.fields[self.field].source_attrs[-1]
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1] self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation. # Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None) self.instance = getattr(serializer, "instance", None)
def enforce_required_fields(self, attrs): def enforce_required_fields(self, attrs):
""" """
@ -210,10 +211,10 @@ class BaseUniqueForValidator(object):
if field_name not in attrs if field_name not in attrs
} }
if missing_items: if missing_items:
raise ValidationError(missing_items, code='required') raise ValidationError(missing_items, code="required")
def filter_queryset(self, attrs, queryset): def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.') raise NotImplementedError("`filter_queryset` must be implemented.")
def exclude_current_instance(self, attrs, queryset): def exclude_current_instance(self, attrs, queryset):
""" """
@ -231,17 +232,18 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset) queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset): if qs_exists(queryset):
message = self.message.format(date_field=self.date_field) message = self.message.format(date_field=self.date_field)
raise ValidationError({ raise ValidationError({self.field: message}, code="unique")
self.field: message
}, code='unique')
def __repr__(self): def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( return unicode_to_repr(
self.__class__.__name__, "<%s(queryset=%s, field=%s, date_field=%s)>"
smart_repr(self.queryset), % (
smart_repr(self.field), self.__class__.__name__,
smart_repr(self.date_field) smart_repr(self.queryset),
)) smart_repr(self.field),
smart_repr(self.date_field),
)
)
class UniqueForDateValidator(BaseUniqueForValidator): class UniqueForDateValidator(BaseUniqueForValidator):
@ -253,9 +255,9 @@ class UniqueForDateValidator(BaseUniqueForValidator):
filter_kwargs = {} filter_kwargs = {}
filter_kwargs[self.field_name] = value filter_kwargs[self.field_name] = value
filter_kwargs['%s__day' % self.date_field_name] = date.day filter_kwargs["%s__day" % self.date_field_name] = date.day
filter_kwargs['%s__month' % self.date_field_name] = date.month filter_kwargs["%s__month" % self.date_field_name] = date.month
filter_kwargs['%s__year' % self.date_field_name] = date.year filter_kwargs["%s__year" % self.date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
@ -268,7 +270,7 @@ class UniqueForMonthValidator(BaseUniqueForValidator):
filter_kwargs = {} filter_kwargs = {}
filter_kwargs[self.field_name] = value filter_kwargs[self.field_name] = value
filter_kwargs['%s__month' % self.date_field_name] = date.month filter_kwargs["%s__month" % self.date_field_name] = date.month
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)
@ -281,5 +283,5 @@ class UniqueForYearValidator(BaseUniqueForValidator):
filter_kwargs = {} filter_kwargs = {}
filter_kwargs[self.field_name] = value filter_kwargs[self.field_name] = value
filter_kwargs['%s__year' % self.date_field_name] = date.year filter_kwargs["%s__year" % self.date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs) return qs_filter(queryset, **filter_kwargs)

View File

@ -19,19 +19,20 @@ class BaseVersioning(object):
version_param = api_settings.VERSION_PARAM version_param = api_settings.VERSION_PARAM
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
msg = '{cls}.determine_version() must be implemented.' msg = "{cls}.determine_version() must be implemented."
raise NotImplementedError(msg.format( raise NotImplementedError(msg.format(cls=self.__class__.__name__))
cls=self.__class__.__name__
))
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
return _reverse(viewname, args, kwargs, request, format, **extra) return _reverse(viewname, args, kwargs, request, format, **extra)
def is_allowed_version(self, version): def is_allowed_version(self, version):
if not self.allowed_versions: if not self.allowed_versions:
return True return True
return ((version is not None and version == self.default_version) or return (version is not None and version == self.default_version) or (
(version in self.allowed_versions)) version in self.allowed_versions
)
class AcceptHeaderVersioning(BaseVersioning): class AcceptHeaderVersioning(BaseVersioning):
@ -40,6 +41,7 @@ class AcceptHeaderVersioning(BaseVersioning):
Host: example.com Host: example.com
Accept: application/json; version=1.0 Accept: application/json; version=1.0
""" """
invalid_version_message = _('Invalid version in "Accept" header.') invalid_version_message = _('Invalid version in "Accept" header.')
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
@ -71,7 +73,8 @@ class URLPathVersioning(BaseVersioning):
Host: example.com Host: example.com
Accept: application/json Accept: application/json
""" """
invalid_version_message = _('Invalid version in URL path.')
invalid_version_message = _("Invalid version in URL path.")
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
version = kwargs.get(self.version_param, self.default_version) version = kwargs.get(self.version_param, self.default_version)
@ -82,7 +85,9 @@ class URLPathVersioning(BaseVersioning):
raise exceptions.NotFound(self.invalid_version_message) raise exceptions.NotFound(self.invalid_version_message)
return version return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
if request.version is not None: if request.version is not None:
kwargs = {} if (kwargs is None) else kwargs kwargs = {} if (kwargs is None) else kwargs
kwargs[self.version_param] = request.version kwargs[self.version_param] = request.version
@ -116,21 +121,26 @@ class NamespaceVersioning(BaseVersioning):
Host: example.com Host: example.com
Accept: application/json Accept: application/json
""" """
invalid_version_message = _('Invalid version in URL path. Does not match any version namespace.')
invalid_version_message = _(
"Invalid version in URL path. Does not match any version namespace."
)
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
resolver_match = getattr(request, 'resolver_match', None) resolver_match = getattr(request, "resolver_match", None)
if resolver_match is None or not resolver_match.namespace: if resolver_match is None or not resolver_match.namespace:
return self.default_version return self.default_version
# Allow for possibly nested namespaces. # Allow for possibly nested namespaces.
possible_versions = resolver_match.namespace.split(':') possible_versions = resolver_match.namespace.split(":")
for version in possible_versions: for version in possible_versions:
if self.is_allowed_version(version): if self.is_allowed_version(version):
return version return version
raise exceptions.NotFound(self.invalid_version_message) raise exceptions.NotFound(self.invalid_version_message)
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
if request.version is not None: if request.version is not None:
viewname = self.get_versioned_viewname(viewname, request) viewname = self.get_versioned_viewname(viewname, request)
return super(NamespaceVersioning, self).reverse( return super(NamespaceVersioning, self).reverse(
@ -138,7 +148,7 @@ class NamespaceVersioning(BaseVersioning):
) )
def get_versioned_viewname(self, viewname, request): def get_versioned_viewname(self, viewname, request):
return request.version + ':' + viewname return request.version + ":" + viewname
class HostNameVersioning(BaseVersioning): class HostNameVersioning(BaseVersioning):
@ -147,11 +157,12 @@ class HostNameVersioning(BaseVersioning):
Host: v1.example.com Host: v1.example.com
Accept: application/json Accept: application/json
""" """
hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
invalid_version_message = _('Invalid version in hostname.') hostname_regex = re.compile(r"^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$")
invalid_version_message = _("Invalid version in hostname.")
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
hostname, separator, port = request.get_host().partition(':') hostname, separator, port = request.get_host().partition(":")
match = self.hostname_regex.match(hostname) match = self.hostname_regex.match(hostname)
if not match: if not match:
return self.default_version return self.default_version
@ -170,7 +181,8 @@ class QueryParameterVersioning(BaseVersioning):
Host: example.com Host: example.com
Accept: application/json Accept: application/json
""" """
invalid_version_message = _('Invalid version in query parameter.')
invalid_version_message = _("Invalid version in query parameter.")
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
version = request.query_params.get(self.version_param, self.default_version) version = request.query_params.get(self.version_param, self.default_version)
@ -178,7 +190,9 @@ class QueryParameterVersioning(BaseVersioning):
raise exceptions.NotFound(self.invalid_version_message) raise exceptions.NotFound(self.invalid_version_message)
return version return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
url = super(QueryParameterVersioning, self).reverse( url = super(QueryParameterVersioning, self).reverse(
viewname, args, kwargs, request, format, **extra viewname, args, kwargs, request, format, **extra
) )

View File

@ -29,19 +29,19 @@ def get_view_name(view):
This function is the default for the `VIEW_NAME_FUNCTION` setting. This function is the default for the `VIEW_NAME_FUNCTION` setting.
""" """
# Name may be set by some Views, such as a ViewSet. # Name may be set by some Views, such as a ViewSet.
name = getattr(view, 'name', None) name = getattr(view, "name", None)
if name is not None: if name is not None:
return name return name
name = view.__class__.__name__ name = view.__class__.__name__
name = formatting.remove_trailing_string(name, 'View') name = formatting.remove_trailing_string(name, "View")
name = formatting.remove_trailing_string(name, 'ViewSet') name = formatting.remove_trailing_string(name, "ViewSet")
name = formatting.camelcase_to_spaces(name) name = formatting.camelcase_to_spaces(name)
# Suffix may be set by some Views, such as a ViewSet. # Suffix may be set by some Views, such as a ViewSet.
suffix = getattr(view, 'suffix', None) suffix = getattr(view, "suffix", None)
if suffix: if suffix:
name += ' ' + suffix name += " " + suffix
return name return name
@ -54,9 +54,9 @@ def get_view_description(view, html=False):
This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting. This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
""" """
# Description may be set by some Views, such as a ViewSet. # Description may be set by some Views, such as a ViewSet.
description = getattr(view, 'description', None) description = getattr(view, "description", None)
if description is None: if description is None:
description = view.__class__.__doc__ or '' description = view.__class__.__doc__ or ""
description = formatting.dedent(smart_text(description)) description = formatting.dedent(smart_text(description))
if html: if html:
@ -65,7 +65,7 @@ def get_view_description(view, html=False):
def set_rollback(): def set_rollback():
atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
if atomic_requests and connection.in_atomic_block: if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True) transaction.set_rollback(True)
@ -87,15 +87,15 @@ def exception_handler(exc, context):
if isinstance(exc, exceptions.APIException): if isinstance(exc, exceptions.APIException):
headers = {} headers = {}
if getattr(exc, 'auth_header', None): if getattr(exc, "auth_header", None):
headers['WWW-Authenticate'] = exc.auth_header headers["WWW-Authenticate"] = exc.auth_header
if getattr(exc, 'wait', None): if getattr(exc, "wait", None):
headers['Retry-After'] = '%d' % exc.wait headers["Retry-After"] = "%d" % exc.wait
if isinstance(exc.detail, (list, dict)): if isinstance(exc.detail, (list, dict)):
data = exc.detail data = exc.detail
else: else:
data = {'detail': exc.detail} data = {"detail": exc.detail}
set_rollback() set_rollback()
return Response(data, status=exc.status_code, headers=headers) return Response(data, status=exc.status_code, headers=headers)
@ -128,13 +128,15 @@ class APIView(View):
This allows us to discover information about the view when we do URL This allows us to discover information about the view when we do URL
reverse lookups. Used for breadcrumb generation. reverse lookups. Used for breadcrumb generation.
""" """
if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet): if isinstance(getattr(cls, "queryset", None), models.query.QuerySet):
def force_evaluation(): def force_evaluation():
raise RuntimeError( raise RuntimeError(
'Do not evaluate the `.queryset` attribute directly, ' "Do not evaluate the `.queryset` attribute directly, "
'as the result will be cached and reused between requests. ' "as the result will be cached and reused between requests. "
'Use `.all()` or call `.get_queryset()` instead.' "Use `.all()` or call `.get_queryset()` instead."
) )
cls.queryset._fetch_all = force_evaluation cls.queryset._fetch_all = force_evaluation
view = super(APIView, cls).as_view(**initkwargs) view = super(APIView, cls).as_view(**initkwargs)
@ -154,11 +156,9 @@ class APIView(View):
@property @property
def default_response_headers(self): def default_response_headers(self):
headers = { headers = {"Allow": ", ".join(self.allowed_methods)}
'Allow': ', '.join(self.allowed_methods),
}
if len(self.renderer_classes) > 1: if len(self.renderer_classes) > 1:
headers['Vary'] = 'Accept' headers["Vary"] = "Accept"
return headers return headers
def http_method_not_allowed(self, request, *args, **kwargs): def http_method_not_allowed(self, request, *args, **kwargs):
@ -199,9 +199,9 @@ class APIView(View):
# Note: Additionally `request` and `encoding` will also be added # Note: Additionally `request` and `encoding` will also be added
# to the context by the Request object. # to the context by the Request object.
return { return {
'view': self, "view": self,
'args': getattr(self, 'args', ()), "args": getattr(self, "args", ()),
'kwargs': getattr(self, 'kwargs', {}) "kwargs": getattr(self, "kwargs", {}),
} }
def get_renderer_context(self): def get_renderer_context(self):
@ -212,10 +212,10 @@ class APIView(View):
# Note: Additionally 'response' will also be added to the context, # Note: Additionally 'response' will also be added to the context,
# by the Response object. # by the Response object.
return { return {
'view': self, "view": self,
'args': getattr(self, 'args', ()), "args": getattr(self, "args", ()),
'kwargs': getattr(self, 'kwargs', {}), "kwargs": getattr(self, "kwargs", {}),
'request': getattr(self, 'request', None) "request": getattr(self, "request", None),
} }
def get_exception_handler_context(self): def get_exception_handler_context(self):
@ -224,10 +224,10 @@ class APIView(View):
as the `context` argument. as the `context` argument.
""" """
return { return {
'view': self, "view": self,
'args': getattr(self, 'args', ()), "args": getattr(self, "args", ()),
'kwargs': getattr(self, 'kwargs', {}), "kwargs": getattr(self, "kwargs", {}),
'request': getattr(self, 'request', None) "request": getattr(self, "request", None),
} }
def get_view_name(self): def get_view_name(self):
@ -289,7 +289,7 @@ class APIView(View):
""" """
Instantiate and return the content negotiation class to use. Instantiate and return the content negotiation class to use.
""" """
if not getattr(self, '_negotiator', None): if not getattr(self, "_negotiator", None):
self._negotiator = self.content_negotiation_class() self._negotiator = self.content_negotiation_class()
return self._negotiator return self._negotiator
@ -333,7 +333,7 @@ class APIView(View):
for permission in self.get_permissions(): for permission in self.get_permissions():
if not permission.has_permission(request, self): if not permission.has_permission(request, self):
self.permission_denied( self.permission_denied(
request, message=getattr(permission, 'message', None) request, message=getattr(permission, "message", None)
) )
def check_object_permissions(self, request, obj): def check_object_permissions(self, request, obj):
@ -344,7 +344,7 @@ class APIView(View):
for permission in self.get_permissions(): for permission in self.get_permissions():
if not permission.has_object_permission(request, self, obj): if not permission.has_object_permission(request, self, obj):
self.permission_denied( self.permission_denied(
request, message=getattr(permission, 'message', None) request, message=getattr(permission, "message", None)
) )
def check_throttles(self, request): def check_throttles(self, request):
@ -379,7 +379,7 @@ class APIView(View):
parsers=self.get_parsers(), parsers=self.get_parsers(),
authenticators=self.get_authenticators(), authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(), negotiator=self.get_content_negotiator(),
parser_context=parser_context parser_context=parser_context,
) )
def initial(self, request, *args, **kwargs): def initial(self, request, *args, **kwargs):
@ -407,13 +407,12 @@ class APIView(View):
""" """
# Make the error obvious if a proper response is not returned # Make the error obvious if a proper response is not returned
assert isinstance(response, HttpResponseBase), ( assert isinstance(response, HttpResponseBase), (
'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` ' "Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` "
'to be returned from the view, but received a `%s`' "to be returned from the view, but received a `%s`" % type(response)
% type(response)
) )
if isinstance(response, Response): if isinstance(response, Response):
if not getattr(request, 'accepted_renderer', None): if not getattr(request, "accepted_renderer", None):
neg = self.perform_content_negotiation(request, force=True) neg = self.perform_content_negotiation(request, force=True)
request.accepted_renderer, request.accepted_media_type = neg request.accepted_renderer, request.accepted_media_type = neg
@ -422,7 +421,7 @@ class APIView(View):
response.renderer_context = self.get_renderer_context() response.renderer_context = self.get_renderer_context()
# Add new vary headers to the response instead of overwriting. # Add new vary headers to the response instead of overwriting.
vary_headers = self.headers.pop('Vary', None) vary_headers = self.headers.pop("Vary", None)
if vary_headers is not None: if vary_headers is not None:
patch_vary_headers(response, cc_delim_re.split(vary_headers)) patch_vary_headers(response, cc_delim_re.split(vary_headers))
@ -436,8 +435,9 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response, Handle any exception that occurs, by returning an appropriate response,
or re-raising the error. or re-raising the error.
""" """
if isinstance(exc, (exceptions.NotAuthenticated, if isinstance(
exceptions.AuthenticationFailed)): exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
):
# WWW-Authenticate header for 401 responses, else coerce to 403 # WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request) auth_header = self.get_authenticate_header(self.request)
@ -460,8 +460,8 @@ class APIView(View):
def raise_uncaught_exception(self, exc): def raise_uncaught_exception(self, exc):
if settings.DEBUG: if settings.DEBUG:
request = self.request request = self.request
renderer_format = getattr(request.accepted_renderer, 'format') renderer_format = getattr(request.accepted_renderer, "format")
use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin') use_plaintext_traceback = renderer_format not in ("html", "api", "admin")
request.force_plaintext_errors(use_plaintext_traceback) request.force_plaintext_errors(use_plaintext_traceback)
raise exc raise exc
@ -484,8 +484,9 @@ class APIView(View):
# Get the appropriate handler method # Get the appropriate handler method
if request.method.lower() in self.http_method_names: if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(), handler = getattr(
self.http_method_not_allowed) self, request.method.lower(), self.http_method_not_allowed
)
else: else:
handler = self.http_method_not_allowed handler = self.http_method_not_allowed

View File

@ -31,7 +31,7 @@ from rest_framework.reverse import reverse
def _is_extra_action(attr): def _is_extra_action(attr):
return hasattr(attr, 'mapping') return hasattr(attr, "mapping")
class ViewSetMixin(object): class ViewSetMixin(object):
@ -73,24 +73,30 @@ class ViewSetMixin(object):
# actions must not be empty # actions must not be empty
if not actions: if not actions:
raise TypeError("The `actions` argument must be provided when " raise TypeError(
"calling `.as_view()` on a ViewSet. For example " "The `actions` argument must be provided when "
"`.as_view({'get': 'list'})`") "calling `.as_view()` on a ViewSet. For example "
"`.as_view({'get': 'list'})`"
)
# sanitize keyword arguments # sanitize keyword arguments
for key in initkwargs: for key in initkwargs:
if key in cls.http_method_names: if key in cls.http_method_names:
raise TypeError("You tried to pass in the %s method name as a " raise TypeError(
"keyword argument to %s(). Don't do that." "You tried to pass in the %s method name as a "
% (key, cls.__name__)) "keyword argument to %s(). Don't do that." % (key, cls.__name__)
)
if not hasattr(cls, key): if not hasattr(cls, key):
raise TypeError("%s() received an invalid keyword %r" % ( raise TypeError(
cls.__name__, key)) "%s() received an invalid keyword %r" % (cls.__name__, key)
)
# name and suffix are mutually exclusive # name and suffix are mutually exclusive
if 'name' in initkwargs and 'suffix' in initkwargs: if "name" in initkwargs and "suffix" in initkwargs:
raise TypeError("%s() received both `name` and `suffix`, which are " raise TypeError(
"mutually exclusive arguments." % (cls.__name__)) "%s() received both `name` and `suffix`, which are "
"mutually exclusive arguments." % (cls.__name__)
)
def view(request, *args, **kwargs): def view(request, *args, **kwargs):
self = cls(**initkwargs) self = cls(**initkwargs)
@ -105,7 +111,7 @@ class ViewSetMixin(object):
handler = getattr(self, action) handler = getattr(self, action)
setattr(self, method, handler) setattr(self, method, handler)
if hasattr(self, 'get') and not hasattr(self, 'head'): if hasattr(self, "get") and not hasattr(self, "head"):
self.head = self.get self.head = self.get
self.request = request self.request = request
@ -136,11 +142,11 @@ class ViewSetMixin(object):
""" """
request = super(ViewSetMixin, self).initialize_request(request, *args, **kwargs) request = super(ViewSetMixin, self).initialize_request(request, *args, **kwargs)
method = request.method.lower() method = request.method.lower()
if method == 'options': if method == "options":
# This is a special case as we always provide handling for the # This is a special case as we always provide handling for the
# options method in the base `View` class. # options method in the base `View` class.
# Unlike the other explicitly defined actions, 'metadata' is implicit. # Unlike the other explicitly defined actions, 'metadata' is implicit.
self.action = 'metadata' self.action = "metadata"
else: else:
self.action = self.action_map.get(method) self.action = self.action_map.get(method)
return request return request
@ -149,8 +155,8 @@ class ViewSetMixin(object):
""" """
Reverse the action for the given `url_name`. Reverse the action for the given `url_name`.
""" """
url_name = '%s-%s' % (self.basename, url_name) url_name = "%s-%s" % (self.basename, url_name)
kwargs.setdefault('request', self.request) kwargs.setdefault("request", self.request)
return reverse(url_name, *args, **kwargs) return reverse(url_name, *args, **kwargs)
@ -175,13 +181,14 @@ class ViewSetMixin(object):
# filter for the relevant extra actions # filter for the relevant extra actions
actions = [ actions = [
action for action in self.get_extra_actions() action
for action in self.get_extra_actions()
if action.detail == self.detail if action.detail == self.detail
] ]
for action in actions: for action in actions:
try: try:
url_name = '%s-%s' % (self.basename, action.url_name) url_name = "%s-%s" % (self.basename, action.url_name)
url = reverse(url_name, self.args, self.kwargs, request=self.request) url = reverse(url_name, self.args, self.kwargs, request=self.request)
view = self.__class__(**action.kwargs) view = self.__class__(**action.kwargs)
action_urls[view.get_view_name()] = url action_urls[view.get_view_name()] = url
@ -195,6 +202,7 @@ class ViewSet(ViewSetMixin, views.APIView):
""" """
The base ViewSet class does not provide any actions by default. The base ViewSet class does not provide any actions by default.
""" """
pass pass
@ -204,26 +212,31 @@ class GenericViewSet(ViewSetMixin, generics.GenericAPIView):
but does include the base set of generic view behavior, such as but does include the base set of generic view behavior, such as
the `get_object` and `get_queryset` methods. the `get_object` and `get_queryset` methods.
""" """
pass pass
class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, class ReadOnlyModelViewSet(
mixins.ListModelMixin, mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet
GenericViewSet): ):
""" """
A viewset that provides default `list()` and `retrieve()` actions. A viewset that provides default `list()` and `retrieve()` actions.
""" """
pass pass
class ModelViewSet(mixins.CreateModelMixin, class ModelViewSet(
mixins.RetrieveModelMixin, mixins.CreateModelMixin,
mixins.UpdateModelMixin, mixins.RetrieveModelMixin,
mixins.DestroyModelMixin, mixins.UpdateModelMixin,
mixins.ListModelMixin, mixins.DestroyModelMixin,
GenericViewSet): mixins.ListModelMixin,
GenericViewSet,
):
""" """
A viewset that provides default `create()`, `retrieve()`, `update()`, A viewset that provides default `create()`, `retrieve()`, `update()`,
`partial_update()`, `destroy()` and `list()` actions. `partial_update()`, `destroy()` and `list()` actions.
""" """
pass pass

View File

@ -6,16 +6,23 @@ import sys
import pytest import pytest
PYTEST_ARGS = {
'default': [],
'fast': ['-q'],
}
FLAKE8_ARGS = ['rest_framework', 'tests'] PYTEST_ARGS = {"default": [], "fast": ["-q"]}
ISORT_ARGS = ['--recursive', '--check-only', '--diff', '-o' 'uritemplate', '-p', 'tests', 'rest_framework', 'tests'] FLAKE8_ARGS = ["rest_framework", "tests"]
BLACK_ARGS = ['--check', '--verbose'] ISORT_ARGS = [
"--recursive",
"--check-only",
"--diff",
"-o" "uritemplate",
"-p",
"tests",
"rest_framework",
"tests",
]
BLACK_ARGS = ["--check", "--verbose"]
def exit_on_failure(ret, message=None): def exit_on_failure(ret, message=None):
@ -24,43 +31,48 @@ def exit_on_failure(ret, message=None):
def flake8_main(args): def flake8_main(args):
print('Running flake8 code linting') print("Running flake8 code linting")
ret = subprocess.call(['flake8'] + args) ret = subprocess.call(["flake8"] + args)
print('flake8 failed' if ret else 'flake8 passed') print("flake8 failed" if ret else "flake8 passed")
return ret return ret
def isort_main(args): def isort_main(args):
print('Running isort code checking') print("Running isort code checking")
ret = subprocess.call(['isort'] + args) ret = subprocess.call(["isort"] + args)
if ret: if ret:
print('isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`') print(
"isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`"
)
else: else:
print('isort passed') print("isort passed")
return ret return ret
def black_main(args): def black_main(args):
print('Running black code checking') print("Running black code checking")
ret = subprocess.call(['black', '.'] + args) ret = subprocess.call(["black", "."] + args)
if ret: if ret:
print('black failed: Some code have incorrectly formatted. Fix by running `black .`') print(
"black failed: Some code have incorrectly formatted. Fix by running `black .`"
)
else: else:
print('black passed') print("black passed")
return ret return ret
def split_class_and_function(string): def split_class_and_function(string):
class_string, function_string = string.split('.', 1) class_string, function_string = string.split(".", 1)
return "%s and %s" % (class_string, function_string) return "%s and %s" % (class_string, function_string)
def is_function(string): def is_function(string):
# `True` if it looks like a test function is included in the string. # `True` if it looks like a test function is included in the string.
return string.startswith('test_') or '.test_' in string return string.startswith("test_") or ".test_" in string
def is_class(string): def is_class(string):
@ -70,7 +82,7 @@ def is_class(string):
if __name__ == "__main__": if __name__ == "__main__":
try: try:
sys.argv.remove('--nolint') sys.argv.remove("--nolint")
except ValueError: except ValueError:
run_black = True run_black = True
run_flake8 = True run_flake8 = True
@ -81,18 +93,18 @@ if __name__ == "__main__":
run_isort = False run_isort = False
try: try:
sys.argv.remove('--lintonly') sys.argv.remove("--lintonly")
except ValueError: except ValueError:
run_tests = True run_tests = True
else: else:
run_tests = False run_tests = False
try: try:
sys.argv.remove('--fast') sys.argv.remove("--fast")
except ValueError: except ValueError:
style = 'default' style = "default"
else: else:
style = 'fast' style = "fast"
run_black = False run_black = False
run_flake8 = False run_flake8 = False
run_isort = False run_isort = False
@ -102,26 +114,23 @@ if __name__ == "__main__":
first_arg = pytest_args[0] first_arg = pytest_args[0]
try: try:
pytest_args.remove('--coverage') pytest_args.remove("--coverage")
except ValueError: except ValueError:
pass pass
else: else:
pytest_args = [ pytest_args = ["--cov", ".", "--cov-report", "xml"] + pytest_args
'--cov', '.',
'--cov-report', 'xml',
] + pytest_args
if first_arg.startswith('-'): if first_arg.startswith("-"):
# `runtests.py [flags]` # `runtests.py [flags]`
pytest_args = ['tests'] + pytest_args pytest_args = ["tests"] + pytest_args
elif is_class(first_arg) and is_function(first_arg): elif is_class(first_arg) and is_function(first_arg):
# `runtests.py TestCase.test_function [flags]` # `runtests.py TestCase.test_function [flags]`
expression = split_class_and_function(first_arg) expression = split_class_and_function(first_arg)
pytest_args = ['tests', '-k', expression] + pytest_args[1:] pytest_args = ["tests", "-k", expression] + pytest_args[1:]
elif is_class(first_arg) or is_function(first_arg): elif is_class(first_arg) or is_function(first_arg):
# `runtests.py TestCase [flags]` # `runtests.py TestCase [flags]`
# `runtests.py test_function [flags]` # `runtests.py test_function [flags]`
pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:] pytest_args = ["tests", "-k", pytest_args[0]] + pytest_args[1:]
else: else:
pytest_args = PYTEST_ARGS[style] pytest_args = PYTEST_ARGS[style]

View File

@ -9,16 +9,23 @@ addopts=--tb=short --strict -ra
testspath = tests testspath = tests
[flake8] [flake8]
ignore = E501 max-line-length = 120
ignore = E501, W503, E203
banned-modules = json = use from rest_framework.utils import json! banned-modules = json = use from rest_framework.utils import json!
[isort] [isort]
skip=.tox skip=.tox
atomic=true atomic=true
multi_line_output=5 multi_line_output=3
known_standard_library=types lines_after_imports = 2
black=types
combine_as_imports = true
known_third_party=pytest,_pytest,django,pytz known_third_party=pytest,_pytest,django,pytz
known_first_party=rest_framework known_first_party=rest_framework, tests
include_trailing_comma=true
line_length = 88
balanced_wrapping = true
sections = FUTURE, STDLIB, DJANGO, CMS, THIRDPARTY, FIRSTPARTY, LIB, LOCALFOLDER
[coverage:run] [coverage:run]
# NOTE: source is ignored with pytest-cov (but uses the same). # NOTE: source is ignored with pytest-cov (but uses the same).

View File

@ -10,21 +10,21 @@ from setuptools import find_packages, setup
def read(f): def read(f):
return open(f, 'r', encoding='utf-8').read() return open(f, "r", encoding="utf-8").read()
def get_version(package): def get_version(package):
""" """
Return package version as listed in `__version__` in `init.py`. Return package version as listed in `__version__` in `init.py`.
""" """
init_py = open(os.path.join(package, '__init__.py')).read() init_py = open(os.path.join(package, "__init__.py")).read()
return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1) return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1)
version = get_version('rest_framework') version = get_version("rest_framework")
if sys.argv[-1] == 'publish': if sys.argv[-1] == "publish":
if os.system("pip freeze | grep twine"): if os.system("pip freeze | grep twine"):
print("twine not installed.\nUse `pip install twine`.\nExiting.") print("twine not installed.\nUse `pip install twine`.\nExiting.")
sys.exit() sys.exit()
@ -33,48 +33,48 @@ if sys.argv[-1] == 'publish':
print("You probably want to also tag the version now:") print("You probably want to also tag the version now:")
print(" git tag -a %s -m 'version %s'" % (version, version)) print(" git tag -a %s -m 'version %s'" % (version, version))
print(" git push --tags") print(" git push --tags")
shutil.rmtree('dist') shutil.rmtree("dist")
shutil.rmtree('build') shutil.rmtree("build")
shutil.rmtree('djangorestframework.egg-info') shutil.rmtree("djangorestframework.egg-info")
sys.exit() sys.exit()
setup( setup(
name='djangorestframework', name="djangorestframework",
version=version, version=version,
url='https://www.django-rest-framework.org/', url="https://www.django-rest-framework.org/",
license='BSD', license="BSD",
description='Web APIs for Django, made easy.', description="Web APIs for Django, made easy.",
long_description=read('README.md'), long_description=read("README.md"),
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
author='Tom Christie', author="Tom Christie",
author_email='tom@tomchristie.com', # SEE NOTE BELOW (*) author_email="tom@tomchristie.com", # SEE NOTE BELOW (*)
packages=find_packages(exclude=['tests*']), packages=find_packages(exclude=["tests*"]),
include_package_data=True, include_package_data=True,
install_requires=[], install_requires=[],
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
zip_safe=False, zip_safe=False,
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', "Development Status :: 5 - Production/Stable",
'Environment :: Web Environment', "Environment :: Web Environment",
'Framework :: Django', "Framework :: Django",
'Framework :: Django :: 1.11', "Framework :: Django :: 1.11",
'Framework :: Django :: 2.0', "Framework :: Django :: 2.0",
'Framework :: Django :: 2.1', "Framework :: Django :: 2.1",
'Framework :: Django :: 2.2', "Framework :: Django :: 2.2",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'License :: OSI Approved :: BSD License', "License :: OSI Approved :: BSD License",
'Operating System :: OS Independent', "Operating System :: OS Independent",
'Programming Language :: Python', "Programming Language :: Python",
'Programming Language :: Python :: 2', "Programming Language :: Python :: 2",
'Programming Language :: Python :: 2.7', "Programming Language :: Python :: 2.7",
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
'Programming Language :: Python :: 3.4', "Programming Language :: Python :: 3.4",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 3.5",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
'Programming Language :: Python :: 3.7', "Programming Language :: Python :: 3.7",
'Topic :: Internet :: WWW/HTTP', "Topic :: Internet :: WWW/HTTP",
] ],
) )
# (*) Please direct queries to the discussion group, rather than to me directly # (*) Please direct queries to the discussion group, rather than to me directly

View File

@ -9,16 +9,22 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)]
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='CustomToken', name="CustomToken",
fields=[ fields=[
('key', models.CharField(max_length=40, primary_key=True, serialize=False)), (
('user', models.OneToOneField(on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL)), "key",
models.CharField(max_length=40, primary_key=True, serialize=False),
),
(
"user",
models.OneToOneField(
on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL
),
),
], ],
), )
] ]

View File

@ -13,11 +13,18 @@ from django.test import TestCase, override_settings
from django.utils import six from django.utils import six
from rest_framework import ( from rest_framework import (
HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status HTTP_HEADER_ENCODING,
exceptions,
permissions,
renderers,
status,
) )
from rest_framework.authentication import ( from rest_framework.authentication import (
BaseAuthentication, BasicAuthentication, RemoteUserAuthentication, BaseAuthentication,
SessionAuthentication, TokenAuthentication BasicAuthentication,
RemoteUserAuthentication,
SessionAuthentication,
TokenAuthentication,
) )
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import obtain_auth_token from rest_framework.authtoken.views import obtain_auth_token
@ -27,6 +34,7 @@ from rest_framework.views import APIView
from .models import CustomToken from .models import CustomToken
factory = APIRequestFactory() factory = APIRequestFactory()
@ -35,92 +43,77 @@ class CustomTokenAuthentication(TokenAuthentication):
class CustomKeywordTokenAuthentication(TokenAuthentication): class CustomKeywordTokenAuthentication(TokenAuthentication):
keyword = 'Bearer' keyword = "Bearer"
class MockView(APIView): class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,) permission_classes = (permissions.IsAuthenticated,)
def get(self, request): def get(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({"a": 1, "b": 2, "c": 3})
def post(self, request): def post(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({"a": 1, "b": 2, "c": 3})
def put(self, request): def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({"a": 1, "b": 2, "c": 3})
urlpatterns = [ urlpatterns = [
url( url(
r'^session/$', r"^session/$", MockView.as_view(authentication_classes=[SessionAuthentication])
MockView.as_view(authentication_classes=[SessionAuthentication]) ),
url(r"^basic/$", MockView.as_view(authentication_classes=[BasicAuthentication])),
url(
r"^remote-user/$",
MockView.as_view(authentication_classes=[RemoteUserAuthentication]),
),
url(r"^token/$", MockView.as_view(authentication_classes=[TokenAuthentication])),
url(
r"^customtoken/$",
MockView.as_view(authentication_classes=[CustomTokenAuthentication]),
), ),
url( url(
r'^basic/$', r"^customkeywordtoken/$",
MockView.as_view(authentication_classes=[BasicAuthentication]) MockView.as_view(authentication_classes=[CustomKeywordTokenAuthentication]),
), ),
url( url(r"^auth-token/$", obtain_auth_token),
r'^remote-user/$', url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")),
MockView.as_view(authentication_classes=[RemoteUserAuthentication])
),
url(
r'^token/$',
MockView.as_view(authentication_classes=[TokenAuthentication])
),
url(
r'^customtoken/$',
MockView.as_view(authentication_classes=[CustomTokenAuthentication])
),
url(
r'^customkeywordtoken/$',
MockView.as_view(
authentication_classes=[CustomKeywordTokenAuthentication]
)
),
url(r'^auth-token/$', obtain_auth_token),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
] ]
@override_settings(ROOT_URLCONF=__name__) @override_settings(ROOT_URLCONF=__name__)
class BasicAuthTests(TestCase): class BasicAuthTests(TestCase):
"""Basic authentication""" """Basic authentication"""
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = "john"
self.email = 'lennon@thebeatles.com' self.email = "lennon@thebeatles.com"
self.password = 'password' self.password = "password"
self.user = User.objects.create_user( self.user = User.objects.create_user(self.username, self.email, self.password)
self.username, self.email, self.password
)
def test_post_form_passing_basic_auth(self): def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
credentials = ('%s:%s' % (self.username, self.password)) credentials = "%s:%s" % (self.username, self.password)
base64_credentials = base64.b64encode( base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING) credentials.encode(HTTP_HEADER_ENCODING)
).decode(HTTP_HEADER_ENCODING) ).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials auth = "Basic %s" % base64_credentials
response = self.csrf_client.post( response = self.csrf_client.post(
'/basic/', "/basic/", {"example": "example"}, HTTP_AUTHORIZATION=auth
{'example': 'example'},
HTTP_AUTHORIZATION=auth
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
def test_post_json_passing_basic_auth(self): def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
credentials = ('%s:%s' % (self.username, self.password)) credentials = "%s:%s" % (self.username, self.password)
base64_credentials = base64.b64encode( base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING) credentials.encode(HTTP_HEADER_ENCODING)
).decode(HTTP_HEADER_ENCODING) ).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials auth = "Basic %s" % base64_credentials
response = self.csrf_client.post( response = self.csrf_client.post(
'/basic/', "/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
{'example': 'example'},
format='json',
HTTP_AUTHORIZATION=auth
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -128,39 +121,34 @@ class BasicAuthTests(TestCase):
"""Ensure POSTing JSON over basic auth with incorrectly padded Base64 string is handled correctly""" """Ensure POSTing JSON over basic auth with incorrectly padded Base64 string is handled correctly"""
# regression test for issue in 'rest_framework.authentication.BasicAuthentication.authenticate' # regression test for issue in 'rest_framework.authentication.BasicAuthentication.authenticate'
# https://github.com/encode/django-rest-framework/issues/4089 # https://github.com/encode/django-rest-framework/issues/4089
auth = 'Basic =a=' auth = "Basic =a="
response = self.csrf_client.post( response = self.csrf_client.post(
'/basic/', "/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
{'example': 'example'},
format='json',
HTTP_AUTHORIZATION=auth
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_form_failing_basic_auth(self): def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails""" """Ensure POSTing form over basic auth without correct credentials fails"""
response = self.csrf_client.post('/basic/', {'example': 'example'}) response = self.csrf_client.post("/basic/", {"example": "example"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_json_failing_basic_auth(self): def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails""" """Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post( response = self.csrf_client.post(
'/basic/', "/basic/", {"example": "example"}, format="json"
{'example': 'example'},
format='json'
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response['WWW-Authenticate'] == 'Basic realm="api"' assert response["WWW-Authenticate"] == 'Basic realm="api"'
def test_fail_post_if_credentials_are_missing(self): def test_fail_post_if_credentials_are_missing(self):
response = self.csrf_client.post( response = self.csrf_client.post(
'/basic/', {'example': 'example'}, HTTP_AUTHORIZATION='Basic ') "/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic "
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_credentials_contain_spaces(self): def test_fail_post_if_credentials_contain_spaces(self):
response = self.csrf_client.post( response = self.csrf_client.post(
'/basic/', {'example': 'example'}, "/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic foo bar"
HTTP_AUTHORIZATION='Basic foo bar'
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -168,15 +156,14 @@ class BasicAuthTests(TestCase):
@override_settings(ROOT_URLCONF=__name__) @override_settings(ROOT_URLCONF=__name__)
class SessionAuthTests(TestCase): class SessionAuthTests(TestCase):
"""User session authentication""" """User session authentication"""
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.non_csrf_client = APIClient(enforce_csrf_checks=False) self.non_csrf_client = APIClient(enforce_csrf_checks=False)
self.username = 'john' self.username = "john"
self.email = 'lennon@thebeatles.com' self.email = "lennon@thebeatles.com"
self.password = 'password' self.password = "password"
self.user = User.objects.create_user( self.user = User.objects.create_user(self.username, self.email, self.password)
self.username, self.email, self.password
)
def tearDown(self): def tearDown(self):
self.csrf_client.logout() self.csrf_client.logout()
@ -187,8 +174,8 @@ class SessionAuthTests(TestCase):
cf. [#1810](https://github.com/encode/django-rest-framework/pull/1810) cf. [#1810](https://github.com/encode/django-rest-framework/pull/1810)
""" """
response = self.csrf_client.get('/auth/login/') response = self.csrf_client.get("/auth/login/")
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert '<label for="id_username">Username:</label>' in content assert '<label for="id_username">Username:</label>' in content
def test_post_form_session_auth_failing_csrf(self): def test_post_form_session_auth_failing_csrf(self):
@ -196,7 +183,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails. Ensure POSTing form over session authentication without CSRF token fails.
""" """
self.csrf_client.login(username=self.username, password=self.password) self.csrf_client.login(username=self.username, password=self.password)
response = self.csrf_client.post('/session/', {'example': 'example'}) response = self.csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
def test_post_form_session_auth_passing_csrf(self): def test_post_form_session_auth_passing_csrf(self):
@ -213,10 +200,9 @@ class SessionAuthTests(TestCase):
self.csrf_client.cookies[settings.CSRF_COOKIE_NAME] = token self.csrf_client.cookies[settings.CSRF_COOKIE_NAME] = token
# Post the token matching the cookie value # Post the token matching the cookie value
response = self.csrf_client.post('/session/', { response = self.csrf_client.post(
'example': 'example', "/session/", {"example": "example", "csrfmiddlewaretoken": token}
'csrfmiddlewaretoken': token, )
})
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
def test_post_form_session_auth_passing(self): def test_post_form_session_auth_passing(self):
@ -224,12 +210,8 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in Ensure POSTing form over session authentication with logged in
user and CSRF token passes. user and CSRF token passes.
""" """
self.non_csrf_client.login( self.non_csrf_client.login(username=self.username, password=self.password)
username=self.username, password=self.password response = self.non_csrf_client.post("/session/", {"example": "example"})
)
response = self.non_csrf_client.post(
'/session/', {'example': 'example'}
)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
def test_put_form_session_auth_passing(self): def test_put_form_session_auth_passing(self):
@ -237,38 +219,33 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with Ensure PUTting form over session authentication with
logged in user and CSRF token passes. logged in user and CSRF token passes.
""" """
self.non_csrf_client.login( self.non_csrf_client.login(username=self.username, password=self.password)
username=self.username, password=self.password response = self.non_csrf_client.put("/session/", {"example": "example"})
)
response = self.non_csrf_client.put(
'/session/', {'example': 'example'}
)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
def test_post_form_session_auth_failing(self): def test_post_form_session_auth_failing(self):
""" """
Ensure POSTing form over session authentication without logged in user fails. Ensure POSTing form over session authentication without logged in user fails.
""" """
response = self.csrf_client.post('/session/', {'example': 'example'}) response = self.csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
class BaseTokenAuthTests(object): class BaseTokenAuthTests(object):
"""Token authentication""" """Token authentication"""
model = None model = None
path = None path = None
header_prefix = 'Token ' header_prefix = "Token "
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = "john"
self.email = 'lennon@thebeatles.com' self.email = "lennon@thebeatles.com"
self.password = 'password' self.password = "password"
self.user = User.objects.create_user( self.user = User.objects.create_user(self.username, self.email, self.password)
self.username, self.email, self.password
)
self.key = 'abcd1234' self.key = "abcd1234"
self.token = self.model.objects.create(key=self.key, user=self.user) self.token = self.model.objects.create(key=self.key, user=self.user)
def test_post_form_passing_token_auth(self): def test_post_form_passing_token_auth(self):
@ -278,39 +255,41 @@ class BaseTokenAuthTests(object):
""" """
auth = self.header_prefix + self.key auth = self.header_prefix + self.key
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
def test_fail_authentication_if_user_is_not_active(self): def test_fail_authentication_if_user_is_not_active(self):
user = User.objects.create_user('foo', 'bar', 'baz') user = User.objects.create_user("foo", "bar", "baz")
user.is_active = False user.is_active = False
user.save() user.save()
self.model.objects.create(key='foobar_token', user=user) self.model.objects.create(key="foobar_token", user=user)
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, self.path,
HTTP_AUTHORIZATION=self.header_prefix + 'foobar_token' {"example": "example"},
HTTP_AUTHORIZATION=self.header_prefix + "foobar_token",
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_form_passing_nonexistent_token_auth(self): def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key # use a nonexistent token key
auth = self.header_prefix + 'wxyz6789' auth = self.header_prefix + "wxyz6789"
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_token_is_missing(self): def test_fail_post_if_token_is_missing(self):
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, self.path, {"example": "example"}, HTTP_AUTHORIZATION=self.header_prefix
HTTP_AUTHORIZATION=self.header_prefix) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_token_contains_spaces(self): def test_fail_post_if_token_contains_spaces(self):
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, self.path,
HTTP_AUTHORIZATION=self.header_prefix + 'foo bar' {"example": "example"},
HTTP_AUTHORIZATION=self.header_prefix + "foo bar",
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -318,7 +297,7 @@ class BaseTokenAuthTests(object):
# add an 'invalid' unicode character # add an 'invalid' unicode character
auth = self.header_prefix + self.key + "¸" auth = self.header_prefix + self.key + "¸"
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -329,8 +308,7 @@ class BaseTokenAuthTests(object):
""" """
auth = self.header_prefix + self.key auth = self.header_prefix + self.key
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, self.path, {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
format='json', HTTP_AUTHORIZATION=auth
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -343,8 +321,10 @@ class BaseTokenAuthTests(object):
def func_to_test(): def func_to_test():
return self.csrf_client.post( return self.csrf_client.post(
self.path, {'example': 'example'}, self.path,
format='json', HTTP_AUTHORIZATION=auth {"example": "example"},
format="json",
HTTP_AUTHORIZATION=auth,
) )
self.assertNumQueries(1, func_to_test) self.assertNumQueries(1, func_to_test)
@ -353,7 +333,7 @@ class BaseTokenAuthTests(object):
""" """
Ensure POSTing form over token auth without correct credentials fails Ensure POSTing form over token auth without correct credentials fails
""" """
response = self.csrf_client.post(self.path, {'example': 'example'}) response = self.csrf_client.post(self.path, {"example": "example"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
@ -361,7 +341,7 @@ class BaseTokenAuthTests(object):
Ensure POSTing json over token auth without correct credentials fails Ensure POSTing json over token auth without correct credentials fails
""" """
response = self.csrf_client.post( response = self.csrf_client.post(
self.path, {'example': 'example'}, format='json' self.path, {"example": "example"}, format="json"
) )
assert response.status_code == status.HTTP_401_UNAUTHORIZED assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -369,7 +349,7 @@ class BaseTokenAuthTests(object):
@override_settings(ROOT_URLCONF=__name__) @override_settings(ROOT_URLCONF=__name__)
class TokenAuthTests(BaseTokenAuthTests, TestCase): class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token model = Token
path = '/token/' path = "/token/"
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key""" """Ensure creating a token with no key will auto-assign a key"""
@ -387,12 +367,12 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase):
"""Ensure token login view using JSON POST works.""" """Ensure token login view using JSON POST works."""
client = APIClient(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post( response = client.post(
'/auth-token/', "/auth-token/",
{'username': self.username, 'password': self.password}, {"username": self.username, "password": self.password},
format='json' format="json",
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data['token'] == self.key assert response.data["token"] == self.key
def test_token_login_json_bad_creds(self): def test_token_login_json_bad_creds(self):
""" """
@ -401,41 +381,41 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase):
""" """
client = APIClient(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post( response = client.post(
'/auth-token/', "/auth-token/",
{'username': self.username, 'password': "badpass"}, {"username": self.username, "password": "badpass"},
format='json' format="json",
) )
assert response.status_code == 400 assert response.status_code == 400
def test_token_login_json_missing_fields(self): def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields.""" """Ensure token login view using JSON POST fails if missing fields."""
client = APIClient(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post(
{'username': self.username}, format='json') "/auth-token/", {"username": self.username}, format="json"
)
assert response.status_code == 400 assert response.status_code == 400
def test_token_login_form(self): def test_token_login_form(self):
"""Ensure token login view using form POST works.""" """Ensure token login view using form POST works."""
client = APIClient(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post( response = client.post(
'/auth-token/', "/auth-token/", {"username": self.username, "password": self.password}
{'username': self.username, 'password': self.password}
) )
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data['token'] == self.key assert response.data["token"] == self.key
@override_settings(ROOT_URLCONF=__name__) @override_settings(ROOT_URLCONF=__name__)
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase): class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken model = CustomToken
path = '/customtoken/' path = "/customtoken/"
@override_settings(ROOT_URLCONF=__name__) @override_settings(ROOT_URLCONF=__name__)
class CustomKeywordTokenAuthTests(BaseTokenAuthTests, TestCase): class CustomKeywordTokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token model = Token
path = '/customkeywordtoken/' path = "/customkeywordtoken/"
header_prefix = 'Bearer ' header_prefix = "Bearer "
class IncorrectCredentialsTests(TestCase): class IncorrectCredentialsTests(TestCase):
@ -445,42 +425,42 @@ class IncorrectCredentialsTests(TestCase):
authentication should run and error, even if no permissions authentication should run and error, even if no permissions
are set on the view. are set on the view.
""" """
class IncorrectCredentialsAuth(BaseAuthentication): class IncorrectCredentialsAuth(BaseAuthentication):
def authenticate(self, request): def authenticate(self, request):
raise exceptions.AuthenticationFailed('Bad credentials') raise exceptions.AuthenticationFailed("Bad credentials")
request = factory.get('/') request = factory.get("/")
view = MockView.as_view( view = MockView.as_view(
authentication_classes=(IncorrectCredentialsAuth,), authentication_classes=(IncorrectCredentialsAuth,), permission_classes=()
permission_classes=()
) )
response = view(request) response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.data == {'detail': 'Bad credentials'} assert response.data == {"detail": "Bad credentials"}
class FailingAuthAccessedInRenderer(TestCase): class FailingAuthAccessedInRenderer(TestCase):
def setUp(self): def setUp(self):
class AuthAccessingRenderer(renderers.BaseRenderer): class AuthAccessingRenderer(renderers.BaseRenderer):
media_type = 'text/plain' media_type = "text/plain"
format = 'txt' format = "txt"
def render(self, data, media_type=None, renderer_context=None): def render(self, data, media_type=None, renderer_context=None):
request = renderer_context['request'] request = renderer_context["request"]
if request.user.is_authenticated: if request.user.is_authenticated:
return b'authenticated' return b"authenticated"
return b'not authenticated' return b"not authenticated"
class FailingAuth(BaseAuthentication): class FailingAuth(BaseAuthentication):
def authenticate(self, request): def authenticate(self, request):
raise exceptions.AuthenticationFailed('authentication failed') raise exceptions.AuthenticationFailed("authentication failed")
class ExampleView(APIView): class ExampleView(APIView):
authentication_classes = (FailingAuth,) authentication_classes = (FailingAuth,)
renderer_classes = (AuthAccessingRenderer,) renderer_classes = (AuthAccessingRenderer,)
def get(self, request): def get(self, request):
return Response({'foo': 'bar'}) return Response({"foo": "bar"})
self.view = ExampleView.as_view() self.view = ExampleView.as_view()
@ -490,10 +470,10 @@ class FailingAuthAccessedInRenderer(TestCase):
`request.user` without raising an exception. Particularly relevant `request.user` without raising an exception. Particularly relevant
to HTML responses that might reasonably access `request.user`. to HTML responses that might reasonably access `request.user`.
""" """
request = factory.get('/') request = factory.get("/")
response = self.view(request) response = self.view(request)
content = response.render().content content = response.render().content
assert content == b'not authenticated' assert content == b"not authenticated"
class NoAuthenticationClassesTests(TestCase): class NoAuthenticationClassesTests(TestCase):
@ -505,23 +485,21 @@ class NoAuthenticationClassesTests(TestCase):
""" """
class DummyPermission(permissions.BasePermission): class DummyPermission(permissions.BasePermission):
message = 'Dummy permission message' message = "Dummy permission message"
def has_permission(self, request, view): def has_permission(self, request, view):
return False return False
request = factory.get('/') request = factory.get("/")
view = MockView.as_view( view = MockView.as_view(
authentication_classes=(), authentication_classes=(), permission_classes=(DummyPermission,)
permission_classes=(DummyPermission,),
) )
response = view(request) response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.data == {'detail': 'Dummy permission message'} assert response.data == {"detail": "Dummy permission message"}
class BasicAuthenticationUnitTests(TestCase): class BasicAuthenticationUnitTests(TestCase):
def test_base_authentication_abstract_method(self): def test_base_authentication_abstract_method(self):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
BaseAuthentication().authenticate({}) BaseAuthentication().authenticate({})
@ -529,34 +507,34 @@ class BasicAuthenticationUnitTests(TestCase):
def test_basic_authentication_raises_error_if_user_not_found(self): def test_basic_authentication_raises_error_if_user_not_found(self):
auth = BasicAuthentication() auth = BasicAuthentication()
with pytest.raises(exceptions.AuthenticationFailed): with pytest.raises(exceptions.AuthenticationFailed):
auth.authenticate_credentials('invalid id', 'invalid password') auth.authenticate_credentials("invalid id", "invalid password")
def test_basic_authentication_raises_error_if_user_not_active(self): def test_basic_authentication_raises_error_if_user_not_active(self):
from rest_framework import authentication from rest_framework import authentication
class MockUser(object): class MockUser(object):
is_active = False is_active = False
old_authenticate = authentication.authenticate old_authenticate = authentication.authenticate
authentication.authenticate = lambda **kwargs: MockUser() authentication.authenticate = lambda **kwargs: MockUser()
auth = authentication.BasicAuthentication() auth = authentication.BasicAuthentication()
with pytest.raises(exceptions.AuthenticationFailed) as error: with pytest.raises(exceptions.AuthenticationFailed) as error:
auth.authenticate_credentials('foo', 'bar') auth.authenticate_credentials("foo", "bar")
assert 'User inactive or deleted.' in str(error) assert "User inactive or deleted." in str(error)
authentication.authenticate = old_authenticate authentication.authenticate = old_authenticate
@override_settings(ROOT_URLCONF=__name__, @override_settings(
AUTHENTICATION_BACKENDS=('django.contrib.auth.backends.RemoteUserBackend',)) ROOT_URLCONF=__name__,
AUTHENTICATION_BACKENDS=("django.contrib.auth.backends.RemoteUserBackend",),
)
class RemoteUserAuthenticationUnitTests(TestCase): class RemoteUserAuthenticationUnitTests(TestCase):
def setUp(self): def setUp(self):
self.username = 'john' self.username = "john"
self.email = 'lennon@thebeatles.com' self.email = "lennon@thebeatles.com"
self.password = 'password' self.password = "password"
self.user = User.objects.create_user( self.user = User.objects.create_user(self.username, self.email, self.password)
self.username, self.email, self.password
)
def test_remote_user_works(self): def test_remote_user_works(self):
response = self.client.post('/remote-user/', response = self.client.post("/remote-user/", REMOTE_USER=self.username)
REMOTE_USER=self.username)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@ -4,7 +4,8 @@ from django.conf.urls import include, url
from .views import MockView from .views import MockView
urlpatterns = [ urlpatterns = [
url(r'^$', MockView.as_view()), url(r"^$", MockView.as_view()),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")),
] ]

View File

@ -4,6 +4,5 @@ from django.conf.urls import url
from .views import MockView from .views import MockView
urlpatterns = [
url(r'^$', MockView.as_view()), urlpatterns = [url(r"^$", MockView.as_view())]
]

View File

@ -6,71 +6,65 @@ from django.test import TestCase, override_settings
from rest_framework.test import APIClient from rest_framework.test import APIClient
@override_settings(ROOT_URLCONF='tests.browsable_api.auth_urls') @override_settings(ROOT_URLCONF="tests.browsable_api.auth_urls")
class DropdownWithAuthTests(TestCase): class DropdownWithAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views enabled.""" """Tests correct dropdown behaviour with Auth views enabled."""
def setUp(self): def setUp(self):
self.client = APIClient(enforce_csrf_checks=True) self.client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = "john"
self.email = 'lennon@thebeatles.com' self.email = "lennon@thebeatles.com"
self.password = 'password' self.password = "password"
self.user = User.objects.create_user( self.user = User.objects.create_user(self.username, self.email, self.password)
self.username,
self.email,
self.password
)
def tearDown(self): def tearDown(self):
self.client.logout() self.client.logout()
def test_name_shown_when_logged_in(self): def test_name_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password) self.client.login(username=self.username, password=self.password)
response = self.client.get('/') response = self.client.get("/")
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert 'john' in content assert "john" in content
def test_logout_shown_when_logged_in(self): def test_logout_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password) self.client.login(username=self.username, password=self.password)
response = self.client.get('/') response = self.client.get("/")
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert '>Log out<' in content assert ">Log out<" in content
def test_login_shown_when_logged_out(self): def test_login_shown_when_logged_out(self):
response = self.client.get('/') response = self.client.get("/")
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert '>Log in<' in content assert ">Log in<" in content
@override_settings(ROOT_URLCONF='tests.browsable_api.no_auth_urls') @override_settings(ROOT_URLCONF="tests.browsable_api.no_auth_urls")
class NoDropdownWithoutAuthTests(TestCase): class NoDropdownWithoutAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views NOT enabled.""" """Tests correct dropdown behaviour with Auth views NOT enabled."""
def setUp(self): def setUp(self):
self.client = APIClient(enforce_csrf_checks=True) self.client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = "john"
self.email = 'lennon@thebeatles.com' self.email = "lennon@thebeatles.com"
self.password = 'password' self.password = "password"
self.user = User.objects.create_user( self.user = User.objects.create_user(self.username, self.email, self.password)
self.username,
self.email,
self.password
)
def tearDown(self): def tearDown(self):
self.client.logout() self.client.logout()
def test_name_shown_when_logged_in(self): def test_name_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password) self.client.login(username=self.username, password=self.password)
response = self.client.get('/') response = self.client.get("/")
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert 'john' in content assert "john" in content
def test_dropdown_not_shown_when_logged_in(self): def test_dropdown_not_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password) self.client.login(username=self.username, password=self.password)
response = self.client.get('/') response = self.client.get("/")
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert '<li class="dropdown">' not in content assert '<li class="dropdown">' not in content
def test_dropdown_not_shown_when_logged_out(self): def test_dropdown_not_shown_when_logged_out(self):
response = self.client.get('/') response = self.client.get("/")
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert '<li class="dropdown">' not in content assert '<li class="dropdown">' not in content

View File

@ -19,24 +19,22 @@ class NestedSerializerTestSerializer(serializers.Serializer):
class NestedSerializersView(ListCreateAPIView): class NestedSerializersView(ListCreateAPIView):
renderer_classes = (BrowsableAPIRenderer, ) renderer_classes = (BrowsableAPIRenderer,)
serializer_class = NestedSerializerTestSerializer serializer_class = NestedSerializerTestSerializer
queryset = [{'nested': {'one': 1, 'two': 2}}] queryset = [{"nested": {"one": 1, "two": 2}}]
urlpatterns = [ urlpatterns = [url(r"^api/$", NestedSerializersView.as_view(), name="api")]
url(r'^api/$', NestedSerializersView.as_view(), name='api'),
]
class DropdownWithAuthTests(TestCase): class DropdownWithAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views enabled.""" """Tests correct dropdown behaviour with Auth views enabled."""
@override_settings(ROOT_URLCONF='tests.browsable_api.test_browsable_nested_api') @override_settings(ROOT_URLCONF="tests.browsable_api.test_browsable_nested_api")
def test_login(self): def test_login(self):
response = self.client.get('/api/') response = self.client.get("/api/")
assert 200 == response.status_code assert 200 == response.status_code
content = response.content.decode('utf-8') content = response.content.decode("utf-8")
assert 'form action="/api/"' in content assert 'form action="/api/"' in content
assert 'input name="nested.one"' in content assert 'input name="nested.one"' in content
assert 'input name="nested.two"' in content assert 'input name="nested.two"' in content

View File

@ -5,13 +5,14 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from tests.models import BasicModel from tests.models import BasicModel
factory = APIRequestFactory() factory = APIRequestFactory()
class BasicSerializer(serializers.ModelSerializer): class BasicSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = BasicModel model = BasicModel
fields = '__all__' fields = "__all__"
class StandardPostView(generics.CreateAPIView): class StandardPostView(generics.CreateAPIView):
@ -39,19 +40,19 @@ class TestPostingListData(TestCase):
def test_json_response(self): def test_json_response(self):
# sanity check for non-browsable API responses # sanity check for non-browsable API responses
view = StandardPostView.as_view() view = StandardPostView.as_view()
request = factory.post('/', [{}], format='json') request = factory.post("/", [{}], format="json")
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertTrue('non_field_errors' in response.data) self.assertTrue("non_field_errors" in response.data)
def test_browsable_api(self): def test_browsable_api(self):
view = StandardPostView.as_view() view = StandardPostView.as_view()
request = factory.post('/?format=api', [{}], format='json') request = factory.post("/?format=api", [{}], format="json")
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertTrue('non_field_errors' in response.data) self.assertTrue("non_field_errors" in response.data)
class TestManyPostView(TestCase): class TestManyPostView(TestCase):
@ -59,14 +60,11 @@ class TestManyPostView(TestCase):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz'] items = ["foo", "bar", "baz"]
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = ManyPostView.as_view() self.view = ManyPostView.as_view()
def test_post_many_post_view(self): def test_post_many_post_view(self):
@ -77,7 +75,7 @@ class TestManyPostView(TestCase):
Regression test for https://github.com/encode/django-rest-framework/pull/3164 Regression test for https://github.com/encode/django-rest-framework/pull/3164
""" """
data = {} data = {}
request = factory.post('/', data, format='json') request = factory.post("/", data, format="json")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK

View File

@ -10,4 +10,4 @@ class MockView(APIView):
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get(self, request): def get(self, request):
return Response({'a': 1, 'b': 2, 'c': 3}) return Response({"a": 1, "b": 2, "c": 3})

View File

@ -6,13 +6,21 @@ from django.core import management
def pytest_addoption(parser): def pytest_addoption(parser):
parser.addoption('--no-pkgroot', action='store_true', default=False, parser.addoption(
help='Remove package root directory from sys.path, ensuring that ' "--no-pkgroot",
'rest_framework is imported from the installed site-packages. ' action="store_true",
'Used for testing the distribution.') default=False,
parser.addoption('--staticfiles', action='store_true', default=False, help="Remove package root directory from sys.path, ensuring that "
help='Run tests with static files collection, using manifest ' "rest_framework is imported from the installed site-packages. "
'staticfiles storage. Used for testing the distribution.') "Used for testing the distribution.",
)
parser.addoption(
"--staticfiles",
action="store_true",
default=False,
help="Run tests with static files collection, using manifest "
"staticfiles storage. Used for testing the distribution.",
)
def pytest_configure(config): def pytest_configure(config):
@ -21,49 +29,42 @@ def pytest_configure(config):
settings.configure( settings.configure(
DEBUG_PROPAGATE_EXCEPTIONS=True, DEBUG_PROPAGATE_EXCEPTIONS=True,
DATABASES={ DATABASES={
'default': { "default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}
'ENGINE': 'django.db.backends.sqlite3',
'NAME': ':memory:'
}
}, },
SITE_ID=1, SITE_ID=1,
SECRET_KEY='not very secret in tests', SECRET_KEY="not very secret in tests",
USE_I18N=True, USE_I18N=True,
USE_L10N=True, USE_L10N=True,
STATIC_URL='/static/', STATIC_URL="/static/",
ROOT_URLCONF='tests.urls', ROOT_URLCONF="tests.urls",
TEMPLATES=[ TEMPLATES=[
{ {
'BACKEND': 'django.template.backends.django.DjangoTemplates', "BACKEND": "django.template.backends.django.DjangoTemplates",
'APP_DIRS': True, "APP_DIRS": True,
'OPTIONS': { "OPTIONS": {"debug": True}, # We want template errors to raise
"debug": True, # We want template errors to raise }
}
},
], ],
MIDDLEWARE=( MIDDLEWARE=(
'django.middleware.common.CommonMiddleware', "django.middleware.common.CommonMiddleware",
'django.contrib.sessions.middleware.SessionMiddleware', "django.contrib.sessions.middleware.SessionMiddleware",
'django.contrib.auth.middleware.AuthenticationMiddleware', "django.contrib.auth.middleware.AuthenticationMiddleware",
'django.contrib.messages.middleware.MessageMiddleware', "django.contrib.messages.middleware.MessageMiddleware",
), ),
INSTALLED_APPS=( INSTALLED_APPS=(
'django.contrib.admin', "django.contrib.admin",
'django.contrib.auth', "django.contrib.auth",
'django.contrib.contenttypes', "django.contrib.contenttypes",
'django.contrib.sessions', "django.contrib.sessions",
'django.contrib.sites', "django.contrib.sites",
'django.contrib.staticfiles', "django.contrib.staticfiles",
'rest_framework', "rest_framework",
'rest_framework.authtoken', "rest_framework.authtoken",
'tests.authentication', "tests.authentication",
'tests.generic_relations', "tests.generic_relations",
'tests.importable', "tests.importable",
'tests', "tests",
),
PASSWORD_HASHERS=(
'django.contrib.auth.hashers.MD5PasswordHasher',
), ),
PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",),
) )
# guardian is optional # guardian is optional
@ -74,28 +75,32 @@ def pytest_configure(config):
else: else:
settings.ANONYMOUS_USER_ID = -1 settings.ANONYMOUS_USER_ID = -1
settings.AUTHENTICATION_BACKENDS = ( settings.AUTHENTICATION_BACKENDS = (
'django.contrib.auth.backends.ModelBackend', "django.contrib.auth.backends.ModelBackend",
'guardian.backends.ObjectPermissionBackend', "guardian.backends.ObjectPermissionBackend",
)
settings.INSTALLED_APPS += (
'guardian',
) )
settings.INSTALLED_APPS += ("guardian",)
if config.getoption('--no-pkgroot'): if config.getoption("--no-pkgroot"):
sys.path.pop(0) sys.path.pop(0)
# import rest_framework before pytest re-adds the package root directory. # import rest_framework before pytest re-adds the package root directory.
import rest_framework import rest_framework
package_dir = os.path.join(os.getcwd(), 'rest_framework')
package_dir = os.path.join(os.getcwd(), "rest_framework")
assert not rest_framework.__file__.startswith(package_dir) assert not rest_framework.__file__.startswith(package_dir)
# Manifest storage will raise an exception if static files are not present (ie, a packaging failure). # Manifest storage will raise an exception if static files are not present (ie, a packaging failure).
if config.getoption('--staticfiles'): if config.getoption("--staticfiles"):
import rest_framework import rest_framework
settings.STATIC_ROOT = os.path.join(os.path.dirname(rest_framework.__file__), 'static-root')
settings.STATICFILES_STORAGE = 'django.contrib.staticfiles.storage.ManifestStaticFilesStorage' settings.STATIC_ROOT = os.path.join(
os.path.dirname(rest_framework.__file__), "static-root"
)
settings.STATICFILES_STORAGE = (
"django.contrib.staticfiles.storage.ManifestStaticFilesStorage"
)
django.setup() django.setup()
if config.getoption('--staticfiles'): if config.getoption("--staticfiles"):
management.call_command('collectstatic', verbosity=0, interactive=False) management.call_command("collectstatic", verbosity=0, interactive=False)

View File

@ -5,32 +5,59 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [("contenttypes", "0002_remove_content_type_name")]
('contenttypes', '0002_remove_content_type_name'),
]
operations = [ operations = [
migrations.CreateModel( migrations.CreateModel(
name='Bookmark', name="Bookmark",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('url', models.URLField()), "id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("url", models.URLField()),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='Note', name="Note",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('text', models.TextField()), "id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("text", models.TextField()),
], ],
), ),
migrations.CreateModel( migrations.CreateModel(
name='Tag', name="Tag",
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), (
('tag', models.SlugField()), "id",
('object_id', models.PositiveIntegerField()), models.AutoField(
('content_type', models.ForeignKey(on_delete=models.CASCADE, to='contenttypes.ContentType')), auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("tag", models.SlugField()),
("object_id", models.PositiveIntegerField()),
(
"content_type",
models.ForeignKey(
on_delete=models.CASCADE, to="contenttypes.ContentType"
),
),
], ],
), ),
] ]

View File

@ -1,8 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.contrib.contenttypes.fields import ( from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
GenericForeignKey, GenericRelation
)
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import models from django.db import models
from django.utils.encoding import python_2_unicode_compatible from django.utils.encoding import python_2_unicode_compatible
@ -13,10 +11,11 @@ class Tag(models.Model):
""" """
Tags have a descriptive slug, and are attached to an arbitrary object. Tags have a descriptive slug, and are attached to an arbitrary object.
""" """
tag = models.SlugField() tag = models.SlugField()
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
object_id = models.PositiveIntegerField() object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id') tagged_item = GenericForeignKey("content_type", "object_id")
def __str__(self): def __str__(self):
return self.tag return self.tag
@ -27,11 +26,12 @@ class Bookmark(models.Model):
""" """
A URL bookmark that may have multiple tags attached. A URL bookmark that may have multiple tags attached.
""" """
url = models.URLField() url = models.URLField()
tags = GenericRelation(Tag) tags = GenericRelation(Tag)
def __str__(self): def __str__(self):
return 'Bookmark: %s' % self.url return "Bookmark: %s" % self.url
@python_2_unicode_compatible @python_2_unicode_compatible
@ -39,8 +39,9 @@ class Note(models.Model):
""" """
A textual note that may have multiple tags attached. A textual note that may have multiple tags attached.
""" """
text = models.TextField() text = models.TextField()
tags = GenericRelation(Tag) tags = GenericRelation(Tag)
def __str__(self): def __str__(self):
return 'Note: %s' % self.text return "Note: %s" % self.text

View File

@ -9,11 +9,11 @@ from .models import Bookmark, Note, Tag
class TestGenericRelations(TestCase): class TestGenericRelations(TestCase):
def setUp(self): def setUp(self):
self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') self.bookmark = Bookmark.objects.create(url="https://www.djangoproject.com/")
Tag.objects.create(tagged_item=self.bookmark, tag='django') Tag.objects.create(tagged_item=self.bookmark, tag="django")
Tag.objects.create(tagged_item=self.bookmark, tag='python') Tag.objects.create(tagged_item=self.bookmark, tag="python")
self.note = Note.objects.create(text='Remember the milk') self.note = Note.objects.create(text="Remember the milk")
Tag.objects.create(tagged_item=self.note, tag='reminder') Tag.objects.create(tagged_item=self.note, tag="reminder")
def test_generic_relation(self): def test_generic_relation(self):
""" """
@ -26,12 +26,12 @@ class TestGenericRelations(TestCase):
class Meta: class Meta:
model = Bookmark model = Bookmark
fields = ('tags', 'url') fields = ("tags", "url")
serializer = BookmarkSerializer(self.bookmark) serializer = BookmarkSerializer(self.bookmark)
expected = { expected = {
'tags': ['django', 'python'], "tags": ["django", "python"],
'url': 'https://www.djangoproject.com/' "url": "https://www.djangoproject.com/",
} }
assert serializer.data == expected assert serializer.data == expected
@ -46,21 +46,18 @@ class TestGenericRelations(TestCase):
class Meta: class Meta:
model = Tag model = Tag
fields = ('tag', 'tagged_item') fields = ("tag", "tagged_item")
serializer = TagSerializer(Tag.objects.all(), many=True) serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [ expected = [
{ {
'tag': 'django', "tag": "django",
'tagged_item': 'Bookmark: https://www.djangoproject.com/' "tagged_item": "Bookmark: https://www.djangoproject.com/",
}, },
{ {
'tag': 'python', "tag": "python",
'tagged_item': 'Bookmark: https://www.djangoproject.com/' "tagged_item": "Bookmark: https://www.djangoproject.com/",
}, },
{ {"tag": "reminder", "tagged_item": "Note: Remember the milk"},
'tag': 'reminder',
'tagged_item': 'Note: Remember the milk'
}
] ]
assert serializer.data == expected assert serializer.data == expected

View File

@ -5,9 +5,9 @@ from tests import importable
def test_installed(): def test_installed():
# ensure that apps can freely import rest_framework.compat # ensure that apps can freely import rest_framework.compat
assert 'tests.importable' in settings.INSTALLED_APPS assert "tests.importable" in settings.INSTALLED_APPS
def test_imported(): def test_imported():
# ensure that the __init__ hasn't been mucked with # ensure that the __init__ hasn't been mucked with
assert hasattr(importable, 'compat') assert hasattr(importable, "compat")

View File

@ -12,7 +12,7 @@ class RESTFrameworkModel(models.Model):
""" """
class Meta: class Meta:
app_label = 'tests' app_label = "tests"
abstract = True abstract = True
@ -20,7 +20,7 @@ class BasicModel(RESTFrameworkModel):
text = models.CharField( text = models.CharField(
max_length=100, max_length=100,
verbose_name=_("Text comes here"), verbose_name=_("Text comes here"),
help_text=_("Text description.") help_text=_("Text description."),
) )
@ -32,7 +32,7 @@ class ManyToManyTarget(RESTFrameworkModel):
class ManyToManySource(RESTFrameworkModel): class ManyToManySource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
targets = models.ManyToManyField(ManyToManyTarget, related_name='sources') targets = models.ManyToManyField(ManyToManyTarget, related_name="sources")
# ForeignKey # ForeignKey
@ -47,51 +47,74 @@ class UUIDForeignKeyTarget(RESTFrameworkModel):
class ForeignKeySource(RESTFrameworkModel): class ForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, related_name='sources', target = models.ForeignKey(
help_text='Target', verbose_name='Target', ForeignKeyTarget,
on_delete=models.CASCADE) related_name="sources",
help_text="Target",
verbose_name="Target",
on_delete=models.CASCADE,
)
class ForeignKeySourceWithLimitedChoices(RESTFrameworkModel): class ForeignKeySourceWithLimitedChoices(RESTFrameworkModel):
target = models.ForeignKey(ForeignKeyTarget, help_text='Target', target = models.ForeignKey(
verbose_name='Target', ForeignKeyTarget,
limit_choices_to={"name__startswith": "limited-"}, help_text="Target",
on_delete=models.CASCADE) verbose_name="Target",
limit_choices_to={"name__startswith": "limited-"},
on_delete=models.CASCADE,
)
class ForeignKeySourceWithQLimitedChoices(RESTFrameworkModel): class ForeignKeySourceWithQLimitedChoices(RESTFrameworkModel):
target = models.ForeignKey(ForeignKeyTarget, help_text='Target', target = models.ForeignKey(
verbose_name='Target', ForeignKeyTarget,
limit_choices_to=models.Q(name__startswith="limited-"), help_text="Target",
on_delete=models.CASCADE) verbose_name="Target",
limit_choices_to=models.Q(name__startswith="limited-"),
on_delete=models.CASCADE,
)
# Nullable ForeignKey # Nullable ForeignKey
class NullableForeignKeySource(RESTFrameworkModel): class NullableForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, target = models.ForeignKey(
related_name='nullable_sources', ForeignKeyTarget,
verbose_name='Optional target object', null=True,
on_delete=models.CASCADE) blank=True,
related_name="nullable_sources",
verbose_name="Optional target object",
on_delete=models.CASCADE,
)
class NullableUUIDForeignKeySource(RESTFrameworkModel): class NullableUUIDForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True, target = models.ForeignKey(
related_name='nullable_sources', ForeignKeyTarget,
verbose_name='Optional target object', null=True,
on_delete=models.CASCADE) blank=True,
related_name="nullable_sources",
verbose_name="Optional target object",
on_delete=models.CASCADE,
)
class NestedForeignKeySource(RESTFrameworkModel): class NestedForeignKeySource(RESTFrameworkModel):
""" """
Used for testing FK chain. A -> B -> C. Used for testing FK chain. A -> B -> C.
""" """
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.ForeignKey(NullableForeignKeySource, null=True, blank=True, target = models.ForeignKey(
related_name='nested_sources', NullableForeignKeySource,
verbose_name='Intermediate target object', null=True,
on_delete=models.CASCADE) blank=True,
related_name="nested_sources",
verbose_name="Intermediate target object",
on_delete=models.CASCADE,
)
# OneToOne # OneToOne
@ -102,13 +125,21 @@ class OneToOneTarget(RESTFrameworkModel):
class NullableOneToOneSource(RESTFrameworkModel): class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.OneToOneField( target = models.OneToOneField(
OneToOneTarget, null=True, blank=True, OneToOneTarget,
related_name='nullable_source', on_delete=models.CASCADE) null=True,
blank=True,
related_name="nullable_source",
on_delete=models.CASCADE,
)
class OneToOnePKSource(RESTFrameworkModel): class OneToOnePKSource(RESTFrameworkModel):
""" Test model where the primary key is a OneToOneField with another model. """ """ Test model where the primary key is a OneToOneField with another model. """
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
target = models.OneToOneField( target = models.OneToOneField(
OneToOneTarget, primary_key=True, OneToOneTarget,
related_name='required_source', on_delete=models.CASCADE) primary_key=True,
related_name="required_source",
on_delete=models.CASCADE,
)

View File

@ -18,52 +18,75 @@ from rest_framework.views import APIView
def get_schema(): def get_schema():
return coreapi.Document( return coreapi.Document(
url='https://api.example.com/', url="https://api.example.com/",
title='Example API', title="Example API",
content={ content={
'simple_link': coreapi.Link('/example/', description='example link'), "simple_link": coreapi.Link("/example/", description="example link"),
'headers': coreapi.Link('/headers/'), "headers": coreapi.Link("/headers/"),
'location': { "location": {
'query': coreapi.Link('/example/', fields=[ "query": coreapi.Link(
coreapi.Field(name='example', schema=coreschema.String(description='example field')) "/example/",
]), fields=[
'form': coreapi.Link('/example/', action='post', fields=[ coreapi.Field(
coreapi.Field(name='example') name="example",
]), schema=coreschema.String(description="example field"),
'body': coreapi.Link('/example/', action='post', fields=[ )
coreapi.Field(name='example', location='body') ],
]), ),
'path': coreapi.Link('/example/{id}', fields=[ "form": coreapi.Link(
coreapi.Field(name='id', location='path') "/example/", action="post", fields=[coreapi.Field(name="example")]
]) ),
"body": coreapi.Link(
"/example/",
action="post",
fields=[coreapi.Field(name="example", location="body")],
),
"path": coreapi.Link(
"/example/{id}", fields=[coreapi.Field(name="id", location="path")]
),
}, },
'encoding': { "encoding": {
'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ "multipart": coreapi.Link(
coreapi.Field(name='example') "/example/",
]), action="post",
'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ encoding="multipart/form-data",
coreapi.Field(name='example', location='body') fields=[coreapi.Field(name="example")],
]), ),
'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ "multipart-body": coreapi.Link(
coreapi.Field(name='example') "/example/",
]), action="post",
'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ encoding="multipart/form-data",
coreapi.Field(name='example', location='body') fields=[coreapi.Field(name="example", location="body")],
]), ),
'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[ "urlencoded": coreapi.Link(
coreapi.Field(name='example', location='body') "/example/",
]), action="post",
encoding="application/x-www-form-urlencoded",
fields=[coreapi.Field(name="example")],
),
"urlencoded-body": coreapi.Link(
"/example/",
action="post",
encoding="application/x-www-form-urlencoded",
fields=[coreapi.Field(name="example", location="body")],
),
"raw_upload": coreapi.Link(
"/upload/",
action="post",
encoding="application/octet-stream",
fields=[coreapi.Field(name="example", location="body")],
),
}, },
'response': { "response": {
'download': coreapi.Link('/download/'), "download": coreapi.Link("/download/"),
'text': coreapi.Link('/text/') "text": coreapi.Link("/text/"),
} },
} },
) )
def _iterlists(querydict): def _iterlists(querydict):
if hasattr(querydict, 'iterlists'): if hasattr(querydict, "iterlists"):
return querydict.iterlists() return querydict.iterlists()
return querydict.lists() return querydict.lists()
@ -73,8 +96,7 @@ def _get_query_params(request):
# than one item is present for a given key. # than one item is present for a given key.
return { return {
key: (value[0] if len(value) == 1 else value) key: (value[0] if len(value) == 1 else value)
for key, value in for key, value in _iterlists(request.query_params)
_iterlists(request.query_params)
} }
@ -83,7 +105,7 @@ def _get_data(request):
return request.data return request.data
# Coerce multidict into regular dict, and remove files to # Coerce multidict into regular dict, and remove files to
# make assertions simpler. # make assertions simpler.
if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'): if hasattr(request.data, "iterlists") or hasattr(request.data, "lists"):
# Use a list value if a QueryDict contains multiple items for a key. # Use a list value if a QueryDict contains multiple items for a key.
return { return {
key: value[0] if len(value) == 1 else value key: value[0] if len(value) == 1 else value
@ -91,9 +113,7 @@ def _get_data(request):
if key not in request.FILES if key not in request.FILES
} }
return { return {
key: value key: value for key, value in request.data.items() if key not in request.FILES
for key, value in request.data.items()
if key not in request.FILES
} }
@ -101,7 +121,7 @@ def _get_files(request):
if not request.FILES: if not request.FILES:
return {} return {}
return { return {
key: {'name': value.name, 'content': value.read()} key: {"name": value.name, "content": value.read()}
for key, value in request.FILES.items() for key, value in request.FILES.items()
} }
@ -116,210 +136,207 @@ class SchemaView(APIView):
class ListView(APIView): class ListView(APIView):
def get(self, request): def get(self, request):
return Response({ return Response(
'method': request.method, {"method": request.method, "query_params": _get_query_params(request)}
'query_params': _get_query_params(request) )
})
def post(self, request): def post(self, request):
if request.content_type: if request.content_type:
content_type = request.content_type.split(';')[0] content_type = request.content_type.split(";")[0]
else: else:
content_type = None content_type = None
return Response({ return Response(
'method': request.method, {
'query_params': _get_query_params(request), "method": request.method,
'data': _get_data(request), "query_params": _get_query_params(request),
'files': _get_files(request), "data": _get_data(request),
'content_type': content_type "files": _get_files(request),
}) "content_type": content_type,
}
)
class DetailView(APIView): class DetailView(APIView):
def get(self, request, id): def get(self, request, id):
return Response({ return Response(
'id': id, {
'method': request.method, "id": id,
'query_params': _get_query_params(request) "method": request.method,
}) "query_params": _get_query_params(request),
}
)
class UploadView(APIView): class UploadView(APIView):
parser_classes = [FileUploadParser] parser_classes = [FileUploadParser]
def post(self, request): def post(self, request):
return Response({ return Response(
'method': request.method, {
'files': _get_files(request), "method": request.method,
'content_type': request.content_type "files": _get_files(request),
}) "content_type": request.content_type,
}
)
class DownloadView(APIView): class DownloadView(APIView):
def get(self, request): def get(self, request):
return HttpResponse('some file content', content_type='image/png') return HttpResponse("some file content", content_type="image/png")
class TextView(APIView): class TextView(APIView):
def get(self, request): def get(self, request):
return HttpResponse('123', content_type='text/plain') return HttpResponse("123", content_type="text/plain")
class HeadersView(APIView): class HeadersView(APIView):
def get(self, request): def get(self, request):
headers = { headers = {
key[5:].replace('_', '-'): value key[5:].replace("_", "-"): value
for key, value in request.META.items() for key, value in request.META.items()
if key.startswith('HTTP_') if key.startswith("HTTP_")
} }
return Response({ return Response({"method": request.method, "headers": headers})
'method': request.method,
'headers': headers
})
urlpatterns = [ urlpatterns = [
url(r'^$', SchemaView.as_view()), url(r"^$", SchemaView.as_view()),
url(r'^example/$', ListView.as_view()), url(r"^example/$", ListView.as_view()),
url(r'^example/(?P<id>[0-9]+)/$', DetailView.as_view()), url(r"^example/(?P<id>[0-9]+)/$", DetailView.as_view()),
url(r'^upload/$', UploadView.as_view()), url(r"^upload/$", UploadView.as_view()),
url(r'^download/$', DownloadView.as_view()), url(r"^download/$", DownloadView.as_view()),
url(r'^text/$', TextView.as_view()), url(r"^text/$", TextView.as_view()),
url(r'^headers/$', HeadersView.as_view()), url(r"^headers/$", HeadersView.as_view()),
] ]
@unittest.skipUnless(coreapi, 'coreapi not installed') @unittest.skipUnless(coreapi, "coreapi not installed")
@override_settings(ROOT_URLCONF='tests.test_api_client') @override_settings(ROOT_URLCONF="tests.test_api_client")
class APIClientTests(APITestCase): class APIClientTests(APITestCase):
def test_api_client(self): def test_api_client(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
assert schema.title == 'Example API' assert schema.title == "Example API"
assert schema.url == 'https://api.example.com/' assert schema.url == "https://api.example.com/"
assert schema['simple_link'].description == 'example link' assert schema["simple_link"].description == "example link"
assert schema['location']['query'].fields[0].schema.description == 'example field' assert (
data = client.action(schema, ['simple_link']) schema["location"]["query"].fields[0].schema.description == "example field"
expected = { )
'method': 'GET', data = client.action(schema, ["simple_link"])
'query_params': {} expected = {"method": "GET", "query_params": {}}
}
assert data == expected assert data == expected
def test_query_params(self): def test_query_params(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['location', 'query'], params={'example': 123}) data = client.action(schema, ["location", "query"], params={"example": 123})
expected = { expected = {"method": "GET", "query_params": {"example": "123"}}
'method': 'GET',
'query_params': {'example': '123'}
}
assert data == expected assert data == expected
def test_session_headers(self): def test_session_headers(self):
client = CoreAPIClient() client = CoreAPIClient()
client.session.headers.update({'X-Custom-Header': 'foo'}) client.session.headers.update({"X-Custom-Header": "foo"})
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['headers']) data = client.action(schema, ["headers"])
assert data['headers']['X-CUSTOM-HEADER'] == 'foo' assert data["headers"]["X-CUSTOM-HEADER"] == "foo"
def test_query_params_with_multiple_values(self): def test_query_params_with_multiple_values(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]}) data = client.action(
expected = { schema, ["location", "query"], params={"example": [1, 2, 3]}
'method': 'GET', )
'query_params': {'example': ['1', '2', '3']} expected = {"method": "GET", "query_params": {"example": ["1", "2", "3"]}}
}
assert data == expected assert data == expected
def test_form_params(self): def test_form_params(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['location', 'form'], params={'example': 123}) data = client.action(schema, ["location", "form"], params={"example": 123})
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'application/json', "content_type": "application/json",
'query_params': {}, "query_params": {},
'data': {'example': 123}, "data": {"example": 123},
'files': {} "files": {},
} }
assert data == expected assert data == expected
def test_body_params(self): def test_body_params(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['location', 'body'], params={'example': 123}) data = client.action(schema, ["location", "body"], params={"example": 123})
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'application/json', "content_type": "application/json",
'query_params': {}, "query_params": {},
'data': 123, "data": 123,
'files': {} "files": {},
} }
assert data == expected assert data == expected
def test_path_params(self): def test_path_params(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['location', 'path'], params={'id': 123}) data = client.action(schema, ["location", "path"], params={"id": 123})
expected = { expected = {"method": "GET", "query_params": {}, "id": "123"}
'method': 'GET',
'query_params': {},
'id': '123'
}
assert data == expected assert data == expected
def test_multipart_encoding(self): def test_multipart_encoding(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
with tempfile.NamedTemporaryFile() as temp: with tempfile.NamedTemporaryFile() as temp:
temp.write(b'example file content') temp.write(b"example file content")
temp.flush() temp.flush()
temp.seek(0) temp.seek(0)
name = os.path.basename(temp.name) name = os.path.basename(temp.name)
data = client.action(schema, ['encoding', 'multipart'], params={'example': temp}) data = client.action(
schema, ["encoding", "multipart"], params={"example": temp}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'multipart/form-data', "content_type": "multipart/form-data",
'query_params': {}, "query_params": {},
'data': {}, "data": {},
'files': {'example': {'name': name, 'content': 'example file content'}} "files": {"example": {"name": name, "content": "example file content"}},
} }
assert data == expected assert data == expected
def test_multipart_encoding_no_file(self): def test_multipart_encoding_no_file(self):
# When no file is included, multipart encoding should still be used. # When no file is included, multipart encoding should still be used.
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['encoding', 'multipart'], params={'example': 123}) data = client.action(schema, ["encoding", "multipart"], params={"example": 123})
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'multipart/form-data', "content_type": "multipart/form-data",
'query_params': {}, "query_params": {},
'data': {'example': '123'}, "data": {"example": "123"},
'files': {} "files": {},
} }
assert data == expected assert data == expected
def test_multipart_encoding_multiple_values(self): def test_multipart_encoding_multiple_values(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]}) data = client.action(
schema, ["encoding", "multipart"], params={"example": [1, 2, 3]}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'multipart/form-data', "content_type": "multipart/form-data",
'query_params': {}, "query_params": {},
'data': {'example': ['1', '2', '3']}, "data": {"example": ["1", "2", "3"]},
'files': {} "files": {},
} }
assert data == expected assert data == expected
@ -328,17 +345,19 @@ class APIClientTests(APITestCase):
from coreapi.utils import File from coreapi.utils import File
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
example = File(name='example.txt', content='123') example = File(name="example.txt", content="123")
data = client.action(schema, ['encoding', 'multipart'], params={'example': example}) data = client.action(
schema, ["encoding", "multipart"], params={"example": example}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'multipart/form-data', "content_type": "multipart/form-data",
'query_params': {}, "query_params": {},
'data': {}, "data": {},
'files': {'example': {'name': 'example.txt', 'content': '123'}} "files": {"example": {"name": "example.txt", "content": "123"}},
} }
assert data == expected assert data == expected
@ -346,17 +365,19 @@ class APIClientTests(APITestCase):
from coreapi.utils import File from coreapi.utils import File
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'} example = {"foo": File(name="example.txt", content="123"), "bar": "abc"}
data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example}) data = client.action(
schema, ["encoding", "multipart-body"], params={"example": example}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'multipart/form-data', "content_type": "multipart/form-data",
'query_params': {}, "query_params": {},
'data': {'bar': 'abc'}, "data": {"bar": "abc"},
'files': {'foo': {'name': 'example.txt', 'content': '123'}} "files": {"foo": {"name": "example.txt", "content": "123"}},
} }
assert data == expected assert data == expected
@ -364,40 +385,48 @@ class APIClientTests(APITestCase):
def test_urlencoded_encoding(self): def test_urlencoded_encoding(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123}) data = client.action(
schema, ["encoding", "urlencoded"], params={"example": 123}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'application/x-www-form-urlencoded', "content_type": "application/x-www-form-urlencoded",
'query_params': {}, "query_params": {},
'data': {'example': '123'}, "data": {"example": "123"},
'files': {} "files": {},
} }
assert data == expected assert data == expected
def test_urlencoded_encoding_multiple_values(self): def test_urlencoded_encoding_multiple_values(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]}) data = client.action(
schema, ["encoding", "urlencoded"], params={"example": [1, 2, 3]}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'application/x-www-form-urlencoded', "content_type": "application/x-www-form-urlencoded",
'query_params': {}, "query_params": {},
'data': {'example': ['1', '2', '3']}, "data": {"example": ["1", "2", "3"]},
'files': {} "files": {},
} }
assert data == expected assert data == expected
def test_urlencoded_encoding_in_body(self): def test_urlencoded_encoding_in_body(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}}) data = client.action(
schema,
["encoding", "urlencoded-body"],
params={"example": {"foo": 123, "bar": True}},
)
expected = { expected = {
'method': 'POST', "method": "POST",
'content_type': 'application/x-www-form-urlencoded', "content_type": "application/x-www-form-urlencoded",
'query_params': {}, "query_params": {},
'data': {'foo': '123', 'bar': 'true'}, "data": {"foo": "123", "bar": "true"},
'files': {} "files": {},
} }
assert data == expected assert data == expected
@ -405,20 +434,22 @@ class APIClientTests(APITestCase):
def test_raw_upload(self): def test_raw_upload(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
with tempfile.NamedTemporaryFile(delete=False) as temp: with tempfile.NamedTemporaryFile(delete=False) as temp:
temp.write(b'example file content') temp.write(b"example file content")
temp.flush() temp.flush()
temp.seek(0) temp.seek(0)
name = os.path.basename(temp.name) name = os.path.basename(temp.name)
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': temp}) data = client.action(
schema, ["encoding", "raw_upload"], params={"example": temp}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'files': {'file': {'name': name, 'content': 'example file content'}}, "files": {"file": {"name": name, "content": "example file content"}},
'content_type': 'application/octet-stream' "content_type": "application/octet-stream",
} }
assert data == expected assert data == expected
@ -426,15 +457,17 @@ class APIClientTests(APITestCase):
from coreapi.utils import File from coreapi.utils import File
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
example = File('example.txt', '123') example = File("example.txt", "123")
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) data = client.action(
schema, ["encoding", "raw_upload"], params={"example": example}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'files': {'file': {'name': 'example.txt', 'content': '123'}}, "files": {"file": {"name": "example.txt", "content": "123"}},
'content_type': 'text/plain' "content_type": "text/plain",
} }
assert data == expected assert data == expected
@ -442,15 +475,17 @@ class APIClientTests(APITestCase):
from coreapi.utils import File from coreapi.utils import File
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
example = File('example.txt', '123', 'text/html') example = File("example.txt", "123", "text/html")
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) data = client.action(
schema, ["encoding", "raw_upload"], params={"example": example}
)
expected = { expected = {
'method': 'POST', "method": "POST",
'files': {'file': {'name': 'example.txt', 'content': '123'}}, "files": {"file": {"name": "example.txt", "content": "123"}},
'content_type': 'text/html' "content_type": "text/html",
} }
assert data == expected assert data == expected
@ -458,17 +493,17 @@ class APIClientTests(APITestCase):
def test_text_response(self): def test_text_response(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['response', 'text']) data = client.action(schema, ["response", "text"])
expected = '123' expected = "123"
assert data == expected assert data == expected
def test_download_response(self): def test_download_response(self):
client = CoreAPIClient() client = CoreAPIClient()
schema = client.get('http://api.example.com/') schema = client.get("http://api.example.com/")
data = client.action(schema, ['response', 'download']) data = client.action(schema, ["response", "download"])
assert data.basename == 'download.png' assert data.basename == "download.png"
assert data.read() == b'some file content' assert data.read() == b"some file content"

View File

@ -14,13 +14,14 @@ from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
from tests.models import BasicModel from tests.models import BasicModel
factory = APIRequestFactory() factory = APIRequestFactory()
class BasicView(APIView): class BasicView(APIView):
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
BasicModel.objects.create() BasicModel.objects.create()
return Response({'method': 'GET'}) return Response({"method": "GET"})
class ErrorView(APIView): class ErrorView(APIView):
@ -45,25 +46,23 @@ class NonAtomicAPIExceptionView(APIView):
raise Http404 raise Http404
urlpatterns = ( urlpatterns = (url(r"^$", NonAtomicAPIExceptionView.as_view()),)
url(r'^$', NonAtomicAPIExceptionView.as_view()),
)
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints.",
) )
class DBTransactionTests(TestCase): class DBTransactionTests(TestCase):
def setUp(self): def setUp(self):
self.view = BasicView.as_view() self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_no_exception_commit_transaction(self): def test_no_exception_commit_transaction(self):
request = factory.post('/') request = factory.post("/")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request) response = self.view(request)
@ -74,15 +73,15 @@ class DBTransactionTests(TestCase):
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints.",
) )
class DBTransactionErrorTests(TestCase): class DBTransactionErrorTests(TestCase):
def setUp(self): def setUp(self):
self.view = ErrorView.as_view() self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_generic_exception_delegate_transaction_management(self): def test_generic_exception_delegate_transaction_management(self):
""" """
@ -91,7 +90,7 @@ class DBTransactionErrorTests(TestCase):
We let django deal with the transaction when it will catch the Exception. We let django deal with the transaction when it will catch the Exception.
""" """
request = factory.post('/') request = factory.post("/")
with self.assertNumQueries(3): with self.assertNumQueries(3):
# 1 - begin savepoint # 1 - begin savepoint
# 2 - insert # 2 - insert
@ -104,21 +103,21 @@ class DBTransactionErrorTests(TestCase):
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints.",
) )
class DBTransactionAPIExceptionTests(TestCase): class DBTransactionAPIExceptionTests(TestCase):
def setUp(self): def setUp(self):
self.view = APIExceptionView.as_view() self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_api_exception_rollback_transaction(self): def test_api_exception_rollback_transaction(self):
""" """
Transaction is rollbacked by our transaction atomic block. Transaction is rollbacked by our transaction atomic block.
""" """
request = factory.post('/') request = factory.post("/")
num_queries = 4 if connection.features.can_release_savepoints else 3 num_queries = 4 if connection.features.can_release_savepoints else 3
with self.assertNumQueries(num_queries): with self.assertNumQueries(num_queries):
# 1 - begin savepoint # 1 - begin savepoint
@ -134,18 +133,18 @@ class DBTransactionAPIExceptionTests(TestCase):
@unittest.skipUnless( @unittest.skipUnless(
connection.features.uses_savepoints, connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints." "'atomic' requires transactions and savepoints.",
) )
@override_settings(ROOT_URLCONF='tests.test_atomic_requests') @override_settings(ROOT_URLCONF="tests.test_atomic_requests")
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
def setUp(self): def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self): def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_api_exception_rollback_transaction_non_atomic_view(self): def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/') response = self.client.get("/")
# without checking connection.in_atomic_block view raises 500 # without checking connection.in_atomic_block view raises 500
# due attempt to rollback without transaction # due attempt to rollback without transaction

View File

@ -6,44 +6,43 @@ from django.test import TestCase
from django.utils.six import StringIO from django.utils.six import StringIO
from rest_framework.authtoken.admin import TokenAdmin from rest_framework.authtoken.admin import TokenAdmin
from rest_framework.authtoken.management.commands.drf_create_token import \ from rest_framework.authtoken.management.commands.drf_create_token import (
Command as AuthTokenCommand Command as AuthTokenCommand,
)
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError from rest_framework.exceptions import ValidationError
class AuthTokenTests(TestCase): class AuthTokenTests(TestCase):
def setUp(self): def setUp(self):
self.site = site self.site = site
self.user = User.objects.create_user(username='test_user') self.user = User.objects.create_user(username="test_user")
self.token = Token.objects.create(key='test token', user=self.user) self.token = Token.objects.create(key="test token", user=self.user)
def test_model_admin_displayed_fields(self): def test_model_admin_displayed_fields(self):
mock_request = object() mock_request = object()
token_admin = TokenAdmin(self.token, self.site) token_admin = TokenAdmin(self.token, self.site)
assert token_admin.get_fields(mock_request) == ('user',) assert token_admin.get_fields(mock_request) == ("user",)
def test_token_string_representation(self): def test_token_string_representation(self):
assert str(self.token) == 'test token' assert str(self.token) == "test token"
def test_validate_raise_error_if_no_credentials_provided(self): def test_validate_raise_error_if_no_credentials_provided(self):
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
AuthTokenSerializer().validate({}) AuthTokenSerializer().validate({})
def test_whitespace_in_password(self): def test_whitespace_in_password(self):
data = {'username': self.user.username, 'password': 'test pass '} data = {"username": self.user.username, "password": "test pass "}
self.user.set_password(data['password']) self.user.set_password(data["password"])
self.user.save() self.user.save()
assert AuthTokenSerializer(data=data).is_valid() assert AuthTokenSerializer(data=data).is_valid()
class AuthTokenCommandTests(TestCase): class AuthTokenCommandTests(TestCase):
def setUp(self): def setUp(self):
self.site = site self.site = site
self.user = User.objects.create_user(username='test_user') self.user = User.objects.create_user(username="test_user")
def test_command_create_user_token(self): def test_command_create_user_token(self):
token = AuthTokenCommand().create_user_token(self.user.username, False) token = AuthTokenCommand().create_user_token(self.user.username, False)
@ -53,7 +52,7 @@ class AuthTokenCommandTests(TestCase):
def test_command_create_user_token_invalid_user(self): def test_command_create_user_token_invalid_user(self):
with pytest.raises(User.DoesNotExist): with pytest.raises(User.DoesNotExist):
AuthTokenCommand().create_user_token('not_existing_user', False) AuthTokenCommand().create_user_token("not_existing_user", False)
def test_command_reset_user_token(self): def test_command_reset_user_token(self):
AuthTokenCommand().create_user_token(self.user.username, False) AuthTokenCommand().create_user_token(self.user.username, False)
@ -74,12 +73,12 @@ class AuthTokenCommandTests(TestCase):
def test_command_raising_error_for_invalid_user(self): def test_command_raising_error_for_invalid_user(self):
out = StringIO() out = StringIO()
with pytest.raises(CommandError): with pytest.raises(CommandError):
call_command('drf_create_token', 'not_existing_user', stdout=out) call_command("drf_create_token", "not_existing_user", stdout=out)
def test_command_output(self): def test_command_output(self):
out = StringIO() out = StringIO()
call_command('drf_create_token', self.user.username, stdout=out) call_command("drf_create_token", self.user.username, stdout=out)
token_saved = Token.objects.first() token_saved = Token.objects.first()
self.assertIn('Generated token', out.getvalue()) self.assertIn("Generated token", out.getvalue())
self.assertIn(self.user.username, out.getvalue()) self.assertIn(self.user.username, out.getvalue())
self.assertIn(token_saved.key, out.getvalue()) self.assertIn(token_saved.key, out.getvalue())

View File

@ -11,41 +11,43 @@ class TestSimpleBoundField:
serializer = ExampleSerializer() serializer = ExampleSerializer()
assert serializer['text'].value == '' assert serializer["text"].value == ""
assert serializer['text'].errors is None assert serializer["text"].errors is None
assert serializer['text'].name == 'text' assert serializer["text"].name == "text"
assert serializer['amount'].value is None assert serializer["amount"].value is None
assert serializer['amount'].errors is None assert serializer["amount"].errors is None
assert serializer['amount'].name == 'amount' assert serializer["amount"].name == "amount"
def test_populated_bound_field(self): def test_populated_bound_field(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
text = serializers.CharField(max_length=100) text = serializers.CharField(max_length=100)
amount = serializers.IntegerField() amount = serializers.IntegerField()
serializer = ExampleSerializer(data={'text': 'abc', 'amount': 123}) serializer = ExampleSerializer(data={"text": "abc", "amount": 123})
assert serializer.is_valid() assert serializer.is_valid()
assert serializer['text'].value == 'abc' assert serializer["text"].value == "abc"
assert serializer['text'].errors is None assert serializer["text"].errors is None
assert serializer['text'].name == 'text' assert serializer["text"].name == "text"
assert serializer['amount'].value is 123 assert serializer["amount"].value is 123
assert serializer['amount'].errors is None assert serializer["amount"].errors is None
assert serializer['amount'].name == 'amount' assert serializer["amount"].name == "amount"
def test_error_bound_field(self): def test_error_bound_field(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
text = serializers.CharField(max_length=100) text = serializers.CharField(max_length=100)
amount = serializers.IntegerField() amount = serializers.IntegerField()
serializer = ExampleSerializer(data={'text': 'x' * 1000, 'amount': 123}) serializer = ExampleSerializer(data={"text": "x" * 1000, "amount": 123})
serializer.is_valid() serializer.is_valid()
assert serializer['text'].value == 'x' * 1000 assert serializer["text"].value == "x" * 1000
assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.'] assert serializer["text"].errors == [
assert serializer['text'].name == 'text' "Ensure this field has no more than 100 characters."
assert serializer['amount'].value is 123 ]
assert serializer['amount'].errors is None assert serializer["text"].name == "text"
assert serializer['amount'].name == 'amount' assert serializer["amount"].value is 123
assert serializer["amount"].errors is None
assert serializer["amount"].name == "amount"
def test_delete_field(self): def test_delete_field(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
@ -53,41 +55,45 @@ class TestSimpleBoundField:
amount = serializers.IntegerField() amount = serializers.IntegerField()
serializer = ExampleSerializer() serializer = ExampleSerializer()
del serializer.fields['text'] del serializer.fields["text"]
assert 'text' not in serializer.fields assert "text" not in serializer.fields
def test_as_form_fields(self): def test_as_form_fields(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
bool_field = serializers.BooleanField() bool_field = serializers.BooleanField()
null_field = serializers.IntegerField(allow_null=True) null_field = serializers.IntegerField(allow_null=True)
serializer = ExampleSerializer(data={'bool_field': False, 'null_field': None}) serializer = ExampleSerializer(data={"bool_field": False, "null_field": None})
assert serializer.is_valid() assert serializer.is_valid()
assert serializer['bool_field'].as_form_field().value == '' assert serializer["bool_field"].as_form_field().value == ""
assert serializer['null_field'].as_form_field().value == '' assert serializer["null_field"].as_form_field().value == ""
def test_rendering_boolean_field(self): def test_rendering_boolean_field(self):
from rest_framework.renderers import HTMLFormRenderer from rest_framework.renderers import HTMLFormRenderer
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
bool_field = serializers.BooleanField( bool_field = serializers.BooleanField(
style={'base_template': 'checkbox.html', 'template_pack': 'rest_framework/vertical'}) style={
"base_template": "checkbox.html",
"template_pack": "rest_framework/vertical",
}
)
serializer = ExampleSerializer(data={'bool_field': True}) serializer = ExampleSerializer(data={"bool_field": True})
assert serializer.is_valid() assert serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
rendered = renderer.render_field(serializer['bool_field'], {}) rendered = renderer.render_field(serializer["bool_field"], {})
expected_packed = ( expected_packed = (
'<divclass="form-group">' '<divclass="form-group">'
'<divclass="checkbox">' '<divclass="checkbox">'
'<label>' "<label>"
'<inputtype="checkbox"name="bool_field"value="true"checked>' '<inputtype="checkbox"name="bool_field"value="true"checked>'
'Boolfield' "Boolfield"
'</label>' "</label>"
'</div>' "</div>"
'</div>' "</div>"
) )
rendered_packed = ''.join(rendered.split()) rendered_packed = "".join(rendered.split())
assert rendered_packed == expected_packed assert rendered_packed == expected_packed
@ -103,15 +109,15 @@ class TestNestedBoundField:
serializer = ExampleSerializer() serializer = ExampleSerializer()
assert serializer['text'].value == '' assert serializer["text"].value == ""
assert serializer['text'].errors is None assert serializer["text"].errors is None
assert serializer['text'].name == 'text' assert serializer["text"].name == "text"
assert serializer['nested']['more_text'].value == '' assert serializer["nested"]["more_text"].value == ""
assert serializer['nested']['more_text'].errors is None assert serializer["nested"]["more_text"].errors is None
assert serializer['nested']['more_text'].name == 'nested.more_text' assert serializer["nested"]["more_text"].name == "nested.more_text"
assert serializer['nested']['amount'].value is None assert serializer["nested"]["amount"].value is None
assert serializer['nested']['amount'].errors is None assert serializer["nested"]["amount"].errors is None
assert serializer['nested']['amount'].name == 'nested.amount' assert serializer["nested"]["amount"].name == "nested.amount"
def test_as_form_fields(self): def test_as_form_fields(self):
class Nested(serializers.Serializer): class Nested(serializers.Serializer):
@ -121,10 +127,12 @@ class TestNestedBoundField:
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
nested = Nested() nested = Nested()
serializer = ExampleSerializer(data={'nested': {'bool_field': False, 'null_field': None}}) serializer = ExampleSerializer(
data={"nested": {"bool_field": False, "null_field": None}}
)
assert serializer.is_valid() assert serializer.is_valid()
assert serializer['nested']['bool_field'].as_form_field().value == '' assert serializer["nested"]["bool_field"].as_form_field().value == ""
assert serializer['nested']['null_field'].as_form_field().value == '' assert serializer["nested"]["null_field"].as_form_field().value == ""
def test_rendering_nested_fields_with_none_value(self): def test_rendering_nested_fields_with_none_value(self):
from rest_framework.renderers import HTMLFormRenderer from rest_framework.renderers import HTMLFormRenderer
@ -139,28 +147,30 @@ class TestNestedBoundField:
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
nested2 = Nested2() nested2 = Nested2()
serializer = ExampleSerializer(data={'nested2': {'nested1': None, 'text_field': 'test'}}) serializer = ExampleSerializer(
data={"nested2": {"nested1": None, "text_field": "test"}}
)
assert serializer.is_valid() assert serializer.is_valid()
renderer = HTMLFormRenderer() renderer = HTMLFormRenderer()
for field in serializer: for field in serializer:
rendered = renderer.render_field(field, {}) rendered = renderer.render_field(field, {})
expected_packed = ( expected_packed = (
'<fieldset>' "<fieldset>"
'<legend>Nested2</legend>' "<legend>Nested2</legend>"
'<fieldset>' "<fieldset>"
'<legend>Nested1</legend>' "<legend>Nested1</legend>"
'<divclass="form-group">' '<divclass="form-group">'
'<label>Textfield</label>' "<label>Textfield</label>"
'<inputname="nested2.nested1.text_field"class="form-control"type="text"value="">' '<inputname="nested2.nested1.text_field"class="form-control"type="text"value="">'
'</div>' "</div>"
'</fieldset>' "</fieldset>"
'<divclass="form-group">' '<divclass="form-group">'
'<label>Textfield</label>' "<label>Textfield</label>"
'<inputname="nested2.text_field"class="form-control"type="text"value="test">' '<inputname="nested2.text_field"class="form-control"type="text"value="test">'
'</div>' "</div>"
'</fieldset>' "</fieldset>"
) )
rendered_packed = ''.join(rendered.split()) rendered_packed = "".join(rendered.split())
assert rendered_packed == expected_packed assert rendered_packed == expected_packed
@ -170,7 +180,7 @@ class TestJSONBoundField:
json_field = serializers.JSONField() json_field = serializers.JSONField()
data = QueryDict(mutable=True) data = QueryDict(mutable=True)
data.update({'json_field': '{"some": ["json"}'}) data.update({"json_field": '{"some": ["json"}'})
serializer = TestSerializer(data=data) serializer = TestSerializer(data=data)
assert serializer.is_valid() is False assert serializer.is_valid() is False
assert serializer['json_field'].as_form_field().value == '{"some": ["json"}' assert serializer["json_field"].as_form_field().value == '{"some": ["json"}'

View File

@ -6,9 +6,16 @@ from django.test import TestCase
from rest_framework import RemovedInDRF310Warning, status from rest_framework import RemovedInDRF310Warning, status
from rest_framework.authentication import BasicAuthentication from rest_framework.authentication import BasicAuthentication
from rest_framework.decorators import ( from rest_framework.decorators import (
action, api_view, authentication_classes, detail_route, list_route, action,
parser_classes, permission_classes, renderer_classes, schema, api_view,
throttle_classes authentication_classes,
detail_route,
list_route,
parser_classes,
permission_classes,
renderer_classes,
schema,
throttle_classes,
) )
from rest_framework.parsers import JSONParser from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
@ -21,7 +28,6 @@ from rest_framework.views import APIView
class DecoratorTestCase(TestCase): class DecoratorTestCase(TestCase):
def setUp(self): def setUp(self):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
@ -38,7 +44,7 @@ class DecoratorTestCase(TestCase):
def view(request): def view(request):
return Response() return Response()
request = self.factory.get('/') request = self.factory.get("/")
self.assertRaises(AssertionError, view, request) self.assertRaises(AssertionError, view, request)
def test_api_view_incorrect_arguments(self): def test_api_view_incorrect_arguments(self):
@ -47,108 +53,102 @@ class DecoratorTestCase(TestCase):
""" """
with self.assertRaises(AssertionError): with self.assertRaises(AssertionError):
@api_view('GET')
@api_view("GET")
def view(request): def view(request):
return Response() return Response()
def test_calling_method(self): def test_calling_method(self):
@api_view(["GET"])
@api_view(['GET'])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/') request = self.factory.post("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_put_method(self): def test_calling_put_method(self):
@api_view(["GET", "PUT"])
@api_view(['GET', 'PUT'])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.put('/') request = self.factory.put("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/') request = self.factory.post("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_patch_method(self): def test_calling_patch_method(self):
@api_view(["GET", "PATCH"])
@api_view(['GET', 'PATCH'])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.patch('/') request = self.factory.patch("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/') request = self.factory.post("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_renderer_classes(self): def test_renderer_classes(self):
@api_view(["GET"])
@api_view(['GET'])
@renderer_classes([JSONRenderer]) @renderer_classes([JSONRenderer])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get("/")
response = view(request) response = view(request)
assert isinstance(response.accepted_renderer, JSONRenderer) assert isinstance(response.accepted_renderer, JSONRenderer)
def test_parser_classes(self): def test_parser_classes(self):
@api_view(["GET"])
@api_view(['GET'])
@parser_classes([JSONParser]) @parser_classes([JSONParser])
def view(request): def view(request):
assert len(request.parsers) == 1 assert len(request.parsers) == 1
assert isinstance(request.parsers[0], JSONParser) assert isinstance(request.parsers[0], JSONParser)
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get("/")
view(request) view(request)
def test_authentication_classes(self): def test_authentication_classes(self):
@api_view(["GET"])
@api_view(['GET'])
@authentication_classes([BasicAuthentication]) @authentication_classes([BasicAuthentication])
def view(request): def view(request):
assert len(request.authenticators) == 1 assert len(request.authenticators) == 1
assert isinstance(request.authenticators[0], BasicAuthentication) assert isinstance(request.authenticators[0], BasicAuthentication)
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get("/")
view(request) view(request)
def test_permission_classes(self): def test_permission_classes(self):
@api_view(["GET"])
@api_view(['GET'])
@permission_classes([IsAuthenticated]) @permission_classes([IsAuthenticated])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN assert response.status_code == status.HTTP_403_FORBIDDEN
def test_throttle_classes(self): def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle): class OncePerDayUserThrottle(UserRateThrottle):
rate = '1/day' rate = "1/day"
@api_view(['GET']) @api_view(["GET"])
@throttle_classes([OncePerDayUserThrottle]) @throttle_classes([OncePerDayUserThrottle])
def view(request): def view(request):
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get("/")
response = view(request) response = view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -159,10 +159,11 @@ class DecoratorTestCase(TestCase):
""" """
Checks CustomSchema class is set on view Checks CustomSchema class is set on view
""" """
class CustomSchema(AutoSchema): class CustomSchema(AutoSchema):
pass pass
@api_view(['GET']) @api_view(["GET"])
@schema(CustomSchema()) @schema(CustomSchema())
def view(request): def view(request):
return Response({}) return Response({})
@ -171,23 +172,23 @@ class DecoratorTestCase(TestCase):
class ActionDecoratorTestCase(TestCase): class ActionDecoratorTestCase(TestCase):
def test_defaults(self): def test_defaults(self):
@action(detail=True) @action(detail=True)
def test_action(request): def test_action(request):
"""Description""" """Description"""
assert test_action.mapping == {'get': 'test_action'} assert test_action.mapping == {"get": "test_action"}
assert test_action.detail is True assert test_action.detail is True
assert test_action.url_path == 'test_action' assert test_action.url_path == "test_action"
assert test_action.url_name == 'test-action' assert test_action.url_name == "test-action"
assert test_action.kwargs == { assert test_action.kwargs == {
'name': 'Test action', "name": "Test action",
'description': 'Description', "description": "Description",
} }
def test_detail_required(self): def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
@action() @action()
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
@ -201,6 +202,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
for name in APIView.http_method_names: for name in APIView.http_method_names:
def method(): def method():
raise NotImplementedError raise NotImplementedError
@ -222,36 +224,30 @@ class ActionDecoratorTestCase(TestCase):
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {"description": None, "name": "Test action"}
'description': None,
'name': 'Test action',
}
# name kwarg supersedes name generation # name kwarg supersedes name generation
@action(detail=True, name='test name') @action(detail=True, name="test name")
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {"description": None, "name": "test name"}
'description': None,
'name': 'test name',
}
# suffix kwarg supersedes name generation # suffix kwarg supersedes name generation
@action(detail=True, suffix='Suffix') @action(detail=True, suffix="Suffix")
def test_action(request): def test_action(request):
raise NotImplementedError raise NotImplementedError
assert test_action.kwargs == { assert test_action.kwargs == {"description": None, "suffix": "Suffix"}
'description': None,
'suffix': 'Suffix',
}
# name + suffix is a conflict. # name + suffix is a conflict.
with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
action(detail=True, name='test name', suffix='Suffix') action(detail=True, name="test name", suffix="Suffix")
assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments." assert (
str(excinfo.value)
== "`name` and `suffix` are mutually exclusive arguments."
)
def test_method_mapping(self): def test_method_mapping(self):
@action(detail=False) @action(detail=False)
@ -263,7 +259,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError raise NotImplementedError
# The secondary handler methods should not have the action attributes # The secondary handler methods should not have the action attributes
for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']: for name in ["mapping", "detail", "url_path", "url_name", "kwargs"]:
assert hasattr(test_action, name) and not hasattr(test_action_post, name) assert hasattr(test_action, name) and not hasattr(test_action_post, name)
def test_method_mapping_already_mapped(self): def test_method_mapping_already_mapped(self):
@ -273,6 +269,7 @@ class ActionDecoratorTestCase(TestCase):
msg = "Method 'get' has already been mapped to '.test_action'." msg = "Method 'get' has already been mapped to '.test_action'."
with self.assertRaisesMessage(AssertionError, msg): with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.get @test_action.mapping.get
def test_action_get(request): def test_action_get(request):
raise NotImplementedError raise NotImplementedError
@ -282,15 +279,19 @@ class ActionDecoratorTestCase(TestCase):
def test_action(): def test_action():
raise NotImplementedError raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You " msg = (
"cannot use the same method name for each mapping declaration.") "Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration."
)
with self.assertRaisesMessage(AssertionError, msg): with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.post @test_action.mapping.post
def test_action(): def test_action():
raise NotImplementedError raise NotImplementedError
def test_detail_route_deprecation(self): def test_detail_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record: with pytest.warns(RemovedInDRF310Warning) as record:
@detail_route() @detail_route()
def view(request): def view(request):
raise NotImplementedError raise NotImplementedError
@ -304,6 +305,7 @@ class ActionDecoratorTestCase(TestCase):
def test_list_route_deprecation(self): def test_list_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record: with pytest.warns(RemovedInDRF310Warning) as record:
@list_route() @list_route()
def view(request): def view(request):
raise NotImplementedError raise NotImplementedError
@ -318,9 +320,10 @@ class ActionDecoratorTestCase(TestCase):
def test_route_url_name_from_path(self): def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path` # pre-3.8 behavior was to base the `url_name` off of the `url_path`
with pytest.warns(RemovedInDRF310Warning): with pytest.warns(RemovedInDRF310Warning):
@list_route(url_path='foo_bar')
@list_route(url_path="foo_bar")
def view(request): def view(request):
raise NotImplementedError raise NotImplementedError
assert view.url_path == 'foo_bar' assert view.url_path == "foo_bar"
assert view.url_name == 'foo-bar' assert view.url_name == "foo-bar"

View File

@ -9,6 +9,7 @@ from rest_framework.compat import apply_markdown
from rest_framework.utils.formatting import dedent from rest_framework.utils.formatting import dedent
from rest_framework.views import APIView from rest_framework.views import APIView
# We check that docstrings get nicely un-indented. # We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring DESCRIPTION = """an example docstring
==================== ====================
@ -81,28 +82,34 @@ class TestViewNamesAndDescriptions(TestCase):
""" """
Ensure view names are based on the class name. Ensure view names are based on the class name.
""" """
class MockView(APIView): class MockView(APIView):
pass pass
assert MockView().get_view_name() == 'Mock'
assert MockView().get_view_name() == "Mock"
def test_view_name_uses_name_attribute(self): def test_view_name_uses_name_attribute(self):
class MockView(APIView): class MockView(APIView):
name = 'Foo' name = "Foo"
assert MockView().get_view_name() == 'Foo'
assert MockView().get_view_name() == "Foo"
def test_view_name_uses_suffix_attribute(self): def test_view_name_uses_suffix_attribute(self):
class MockView(APIView): class MockView(APIView):
suffix = 'List' suffix = "List"
assert MockView().get_view_name() == 'Mock List'
assert MockView().get_view_name() == "Mock List"
def test_view_name_preferences_name_over_suffix(self): def test_view_name_preferences_name_over_suffix(self):
class MockView(APIView): class MockView(APIView):
name = 'Foo' name = "Foo"
suffix = 'List' suffix = "List"
assert MockView().get_view_name() == 'Foo'
assert MockView().get_view_name() == "Foo"
def test_view_description_uses_docstring(self): def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring.""" """Ensure view descriptions are based on the docstring."""
class MockView(APIView): class MockView(APIView):
"""an example docstring """an example docstring
==================== ====================
@ -130,23 +137,28 @@ class TestViewNamesAndDescriptions(TestCase):
def test_view_description_uses_description_attribute(self): def test_view_description_uses_description_attribute(self):
class MockView(APIView): class MockView(APIView):
description = 'Foo' description = "Foo"
assert MockView().get_view_description() == 'Foo'
assert MockView().get_view_description() == "Foo"
def test_view_description_allows_empty_description(self): def test_view_description_allows_empty_description(self):
class MockView(APIView): class MockView(APIView):
"""Description.""" """Description."""
description = ''
assert MockView().get_view_description() == '' description = ""
assert MockView().get_view_description() == ""
def test_view_description_can_be_empty(self): def test_view_description_can_be_empty(self):
""" """
Ensure that if a view has no docstring, Ensure that if a view has no docstring,
then it's description is the empty string. then it's description is the empty string.
""" """
class MockView(APIView): class MockView(APIView):
pass pass
assert MockView().get_view_description() == ''
assert MockView().get_view_description() == ""
def test_view_description_can_be_promise(self): def test_view_description_can_be_promise(self):
""" """
@ -168,7 +180,7 @@ class TestViewNamesAndDescriptions(TestCase):
class MockView(APIView): class MockView(APIView):
__doc__ = MockLazyStr("a gettext string") __doc__ = MockLazyStr("a gettext string")
assert MockView().get_view_description() == 'a gettext string' assert MockView().get_view_description() == "a gettext string"
def test_markdown(self): def test_markdown(self):
""" """
@ -176,21 +188,17 @@ class TestViewNamesAndDescriptions(TestCase):
""" """
if apply_markdown: if apply_markdown:
md_applied = apply_markdown(DESCRIPTION) md_applied = apply_markdown(DESCRIPTION)
gte_21_match = ( gte_21_match = md_applied == (
md_applied == ( MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE
MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or ) or md_applied == (MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE)
md_applied == ( lt_21_match = md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE)) MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE
lt_21_match = ( ) or md_applied == (MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE)
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
assert gte_21_match or lt_21_match assert gte_21_match or lt_21_match
def test_dedent_tabs(): def test_dedent_tabs():
result = 'first string\n\nsecond string' result = "first string\n\nsecond string"
assert dedent(" first string\n\n second string") == result assert dedent(" first string\n\n second string") == result
assert dedent("first string\n\n second string") == result assert dedent("first string\n\n second string") == result
assert dedent("\tfirst string\n\n\tsecond string") == result assert dedent("\tfirst string\n\n\tsecond string") == result

View File

@ -37,7 +37,7 @@ class JSONEncoderTests(TestCase):
current_time = datetime.now() current_time = datetime.now()
assert self.encoder.default(current_time) == current_time.isoformat() assert self.encoder.default(current_time) == current_time.isoformat()
current_time_utc = current_time.replace(tzinfo=utc) current_time_utc = current_time.replace(tzinfo=utc)
assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z' assert self.encoder.default(current_time_utc) == current_time.isoformat() + "Z"
def test_encode_time(self): def test_encode_time(self):
""" """
@ -76,7 +76,7 @@ class JSONEncoderTests(TestCase):
unique_id = uuid4() unique_id = uuid4()
assert self.encoder.default(unique_id) == str(unique_id) assert self.encoder.default(unique_id) == str(unique_id)
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_encode_coreapi_raises_error(self): def test_encode_coreapi_raises_error(self):
""" """
Tests encoding a coreapi objects raises proper error Tests encoding a coreapi objects raises proper error

View File

@ -6,13 +6,16 @@ from django.utils import six, translation
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import ( from rest_framework.exceptions import (
APIException, ErrorDetail, Throttled, _get_error_details, bad_request, APIException,
server_error ErrorDetail,
Throttled,
_get_error_details,
bad_request,
server_error,
) )
class ExceptionTestCase(TestCase): class ExceptionTestCase(TestCase):
def test_get_error_details(self): def test_get_error_details(self):
example = "string" example = "string"
@ -20,91 +23,96 @@ class ExceptionTestCase(TestCase):
assert _get_error_details(lazy_example) == example assert _get_error_details(lazy_example) == example
assert isinstance( assert isinstance(_get_error_details(lazy_example), ErrorDetail)
_get_error_details(lazy_example),
ErrorDetail
)
assert _get_error_details({'nested': lazy_example})['nested'] == example assert _get_error_details({"nested": lazy_example})["nested"] == example
assert isinstance( assert isinstance(
_get_error_details({'nested': lazy_example})['nested'], _get_error_details({"nested": lazy_example})["nested"], ErrorDetail
ErrorDetail
) )
assert _get_error_details([[lazy_example]])[0][0] == example assert _get_error_details([[lazy_example]])[0][0] == example
assert isinstance( assert isinstance(_get_error_details([[lazy_example]])[0][0], ErrorDetail)
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)
def test_get_full_details_with_throttling(self): def test_get_full_details_with_throttling(self):
exception = Throttled() exception = Throttled()
assert exception.get_full_details() == { assert exception.get_full_details() == {
'message': 'Request was throttled.', 'code': 'throttled'} "message": "Request was throttled.",
"code": "throttled",
}
exception = Throttled(wait=2) exception = Throttled(wait=2)
assert exception.get_full_details() == { assert exception.get_full_details() == {
'message': 'Request was throttled. Expected available in {} seconds.'.format(2 if six.PY3 else 2.), "message": "Request was throttled. Expected available in {} seconds.".format(
'code': 'throttled'} 2 if six.PY3 else 2.0
),
"code": "throttled",
}
exception = Throttled(wait=2, detail='Slow down!') exception = Throttled(wait=2, detail="Slow down!")
assert exception.get_full_details() == { assert exception.get_full_details() == {
'message': 'Slow down! Expected available in {} seconds.'.format(2 if six.PY3 else 2.), "message": "Slow down! Expected available in {} seconds.".format(
'code': 'throttled'} 2 if six.PY3 else 2.0
),
"code": "throttled",
}
class ErrorDetailTests(TestCase): class ErrorDetailTests(TestCase):
def test_eq(self): def test_eq(self):
assert ErrorDetail('msg') == ErrorDetail('msg') assert ErrorDetail("msg") == ErrorDetail("msg")
assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code') assert ErrorDetail("msg", "code") == ErrorDetail("msg", code="code")
assert ErrorDetail('msg') == 'msg' assert ErrorDetail("msg") == "msg"
assert ErrorDetail('msg', 'code') == 'msg' assert ErrorDetail("msg", "code") == "msg"
def test_ne(self): def test_ne(self):
assert ErrorDetail('msg1') != ErrorDetail('msg2') assert ErrorDetail("msg1") != ErrorDetail("msg2")
assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid') assert ErrorDetail("msg") != ErrorDetail("msg", code="invalid")
assert ErrorDetail('msg1') != 'msg2' assert ErrorDetail("msg1") != "msg2"
assert ErrorDetail('msg1', 'code') != 'msg2' assert ErrorDetail("msg1", "code") != "msg2"
def test_repr(self): def test_repr(self):
assert repr(ErrorDetail('msg1')) == \ assert repr(
'ErrorDetail(string={!r}, code=None)'.format('msg1') ErrorDetail("msg1")
assert repr(ErrorDetail('msg1', 'code')) == \ ) == "ErrorDetail(string={!r}, code=None)".format("msg1")
'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code') assert repr(
ErrorDetail("msg1", "code")
) == "ErrorDetail(string={!r}, code={!r})".format("msg1", "code")
def test_str(self): def test_str(self):
assert str(ErrorDetail('msg1')) == 'msg1' assert str(ErrorDetail("msg1")) == "msg1"
assert str(ErrorDetail('msg1', 'code')) == 'msg1' assert str(ErrorDetail("msg1", "code")) == "msg1"
def test_hash(self): def test_hash(self):
assert hash(ErrorDetail('msg')) == hash('msg') assert hash(ErrorDetail("msg")) == hash("msg")
assert hash(ErrorDetail('msg', 'code')) == hash('msg') assert hash(ErrorDetail("msg", "code")) == hash("msg")
class TranslationTests(TestCase): class TranslationTests(TestCase):
@translation.override("fr")
@translation.override('fr')
def test_message(self): def test_message(self):
# this test largely acts as a sanity test to ensure the translation files are present. # this test largely acts as a sanity test to ensure the translation files are present.
self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.') self.assertEqual(
self.assertEqual(six.text_type(APIException()), 'Une erreur du serveur est survenue.') _("A server error occurred."), "Une erreur du serveur est survenue."
)
self.assertEqual(
six.text_type(APIException()), "Une erreur du serveur est survenue."
)
def test_server_error(): def test_server_error():
request = RequestFactory().get('/') request = RequestFactory().get("/")
response = server_error(request) response = server_error(request)
assert response.status_code == 500 assert response.status_code == 500
assert response["content-type"] == 'application/json' assert response["content-type"] == "application/json"
def test_bad_request(): def test_bad_request():
request = RequestFactory().get('/') request = RequestFactory().get("/")
exception = Exception('Something went wrong — Not used') exception = Exception("Something went wrong — Not used")
response = bad_request(request, exception) response = bad_request(request, exception)
assert response.status_code == 400 assert response.status_code == 400
assert response["content-type"] == 'application/json' assert response["content-type"] == "application/json"

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ from rest_framework import filters, generics, serializers
from rest_framework.compat import coreschema from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
@ -30,7 +31,7 @@ class BaseFilterTests(TestCase):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
self.filter_backend.filter_queryset(None, None, None) self.filter_backend.filter_queryset(None, None, None)
@pytest.mark.skipif(not coreschema, reason='coreschema is not installed') @pytest.mark.skipif(not coreschema, reason="coreschema is not installed")
def test_get_schema_fields_checks_for_coreapi(self): def test_get_schema_fields_checks_for_coreapi(self):
filters.coreapi = None filters.coreapi = None
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
@ -47,7 +48,7 @@ class SearchFilterModel(models.Model):
class SearchFilterSerializer(serializers.ModelSerializer): class SearchFilterSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = SearchFilterModel model = SearchFilterModel
fields = '__all__' fields = "__all__"
class SearchFilterTests(TestCase): class SearchFilterTests(TestCase):
@ -59,12 +60,8 @@ class SearchFilterTests(TestCase):
# zzz cde # zzz cde
# ... # ...
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = "z" * (idx + 1)
text = ( text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModel(title=title, text=text).save() SearchFilterModel(title=title, text=text).save()
def test_search(self): def test_search(self):
@ -72,14 +69,14 @@ class SearchFilterTests(TestCase):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text') search_fields = ("title", "text")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'}) request = factory.get("/", {"search": "b"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'}, {"id": 1, "title": "z", "text": "abc"},
{'id': 2, 'title': 'zz', 'text': 'bcd'} {"id": 2, "title": "zz", "text": "bcd"},
] ]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self): def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
@ -89,10 +86,11 @@ class SearchFilterTests(TestCase):
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/') request = factory.get("/")
response = view(request) response = view(request)
expected = SearchFilterSerializer(SearchFilterModel.objects.all(), expected = SearchFilterSerializer(
many=True).data SearchFilterModel.objects.all(), many=True
).data
assert response.data == expected assert response.data == expected
def test_exact_search(self): def test_exact_search(self):
@ -100,59 +98,53 @@ class SearchFilterTests(TestCase):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text') search_fields = ("=title", "text")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'zzz'}) request = factory.get("/", {"search": "zzz"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [{"id": 3, "title": "zzz", "text": "cde"}]
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
def test_startswith_search(self): def test_startswith_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text') search_fields = ("title", "^text")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'}) request = factory.get("/", {"search": "b"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [{"id": 2, "title": "zz", "text": "bcd"}]
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_regexp_search(self): def test_regexp_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('$title', '$text') search_fields = ("$title", "$text")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'z{2} ^b'}) request = factory.get("/", {"search": "z{2} ^b"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [{"id": 2, "title": "zz", "text": "bcd"}]
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
def test_search_with_nonstandard_search_param(self): def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}): with override_settings(REST_FRAMEWORK={"SEARCH_PARAM": "query"}):
reload_module(filters) reload_module(filters)
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text') search_fields = ("title", "text")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'query': 'b'}) request = factory.get("/", {"query": "b"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'}, {"id": 1, "title": "z", "text": "abc"},
{'id': 2, 'title': 'zz', 'text': 'bcd'} {"id": 2, "title": "zz", "text": "bcd"},
] ]
reload_module(filters) reload_module(filters)
@ -161,26 +153,24 @@ class SearchFilterTests(TestCase):
class CustomSearchFilter(filters.SearchFilter): class CustomSearchFilter(filters.SearchFilter):
# Filter that dynamically changes search fields # Filter that dynamically changes search fields
def get_search_fields(self, view, request): def get_search_fields(self, view, request):
if request.query_params.get('title_only'): if request.query_params.get("title_only"):
return ('$title',) return ("$title",)
return super(CustomSearchFilter, self).get_search_fields(view, request) return super(CustomSearchFilter, self).get_search_fields(view, request)
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all() queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,) filter_backends = (CustomSearchFilter,)
search_fields = ('$title', '$text') search_fields = ("$title", "$text")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': r'^\w{3}$'}) request = factory.get("/", {"search": r"^\w{3}$"})
response = view(request) response = view(request)
assert len(response.data) == 10 assert len(response.data) == 10
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'}) request = factory.get("/", {"search": r"^\w{3}$", "title_only": "true"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [{"id": 3, "title": "zzz", "text": "cde"}]
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
class AttributeModel(models.Model): class AttributeModel(models.Model):
@ -195,33 +185,31 @@ class SearchFilterModelFk(models.Model):
class SearchFilterFkSerializer(serializers.ModelSerializer): class SearchFilterFkSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = SearchFilterModelFk model = SearchFilterModelFk
fields = '__all__' fields = "__all__"
class SearchFilterFkTests(TestCase): class SearchFilterFkTests(TestCase):
def test_must_call_distinct(self): def test_must_call_distinct(self):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct(
SearchFilterModelFk._meta, SearchFilterModelFk._meta, ["%stitle" % prefix]
["%stitle" % prefix]
) )
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct(
SearchFilterModelFk._meta, SearchFilterModelFk._meta,
["%stitle" % prefix, "%sattribute__label" % prefix] ["%stitle" % prefix, "%sattribute__label" % prefix],
) )
def test_must_call_distinct_restores_meta_for_each_field(self): def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the # In this test case the attribute of the fk model comes first in the
# list of search fields. # list of search fields.
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct(
SearchFilterModelFk._meta, SearchFilterModelFk._meta,
["%sattribute__label" % prefix, "%stitle" % prefix] ["%sattribute__label" % prefix, "%stitle" % prefix],
) )
@ -234,7 +222,7 @@ class SearchFilterModelM2M(models.Model):
class SearchFilterM2MSerializer(serializers.ModelSerializer): class SearchFilterM2MSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = SearchFilterModelM2M model = SearchFilterModelM2M
fields = '__all__' fields = "__all__"
class SearchFilterM2MTests(TestCase): class SearchFilterM2MTests(TestCase):
@ -246,43 +234,38 @@ class SearchFilterM2MTests(TestCase):
# zzz cde [1, 2, 3] # zzz cde [1, 2, 3]
# ... # ...
for idx in range(3): for idx in range(3):
label = 'w' * (idx + 1) label = "w" * (idx + 1)
AttributeModel.objects.create(label=label) AttributeModel.objects.create(label=label)
for idx in range(10): for idx in range(10):
title = 'z' * (idx + 1) title = "z" * (idx + 1)
text = ( text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
SearchFilterModelM2M(title=title, text=text).save() SearchFilterModelM2M(title=title, text=text).save()
SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3) SearchFilterModelM2M.objects.get(title="zz").attributes.add(1, 2, 3)
def test_m2m_search(self): def test_m2m_search(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModelM2M.objects.all() queryset = SearchFilterModelM2M.objects.all()
serializer_class = SearchFilterM2MSerializer serializer_class = SearchFilterM2MSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text', 'attributes__label') search_fields = ("=title", "text", "attributes__label")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'zz'}) request = factory.get("/", {"search": "zz"})
response = view(request) response = view(request)
assert len(response.data) == 1 assert len(response.data) == 1
def test_must_call_distinct(self): def test_must_call_distinct(self):
filter_ = filters.SearchFilter() filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes) prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes: for prefix in prefixes:
assert not filter_.must_call_distinct( assert not filter_.must_call_distinct(
SearchFilterModelM2M._meta, SearchFilterModelM2M._meta, ["%stitle" % prefix]
["%stitle" % prefix]
) )
assert filter_.must_call_distinct( assert filter_.must_call_distinct(
SearchFilterModelM2M._meta, SearchFilterModelM2M._meta,
["%stitle" % prefix, "%sattributes__label" % prefix] ["%stitle" % prefix, "%sattributes__label" % prefix],
) )
@ -299,33 +282,46 @@ class Entry(models.Model):
class BlogSerializer(serializers.ModelSerializer): class BlogSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Blog model = Blog
fields = '__all__' fields = "__all__"
class SearchFilterToManyTests(TestCase): class SearchFilterToManyTests(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
b1 = Blog.objects.create(name='Blog 1') b1 = Blog.objects.create(name="Blog 1")
b2 = Blog.objects.create(name='Blog 2') b2 = Blog.objects.create(name="Blog 2")
# Multiple entries on Lennon published in 1979 - distinct should deduplicate # Multiple entries on Lennon published in 1979 - distinct should deduplicate
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(
Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1)) blog=b1,
headline="Something about Lennon",
pub_date=datetime.date(1979, 1, 1),
)
Entry.objects.create(
blog=b1,
headline="Another thing about Lennon",
pub_date=datetime.date(1979, 6, 1),
)
# Entry on Lennon *and* a separate entry in 1979 - should not match # Entry on Lennon *and* a separate entry in 1979 - should not match
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1)) Entry.objects.create(
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1)) blog=b2, headline="Something unrelated", pub_date=datetime.date(1979, 1, 1)
)
Entry.objects.create(
blog=b2,
headline="Retrospective on Lennon",
pub_date=datetime.date(1990, 6, 1),
)
def test_multiple_filter_conditions(self): def test_multiple_filter_conditions(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = Blog.objects.all() queryset = Blog.objects.all()
serializer_class = BlogSerializer serializer_class = BlogSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year') search_fields = ("=name", "entry__headline", "=entry__pub_date__year")
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'Lennon,1979'}) request = factory.get("/", {"search": "Lennon,1979"})
response = view(request) response = view(request)
assert len(response.data) == 1 assert len(response.data) == 1
@ -335,60 +331,58 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = SearchFilterModel model = SearchFilterModel
fields = ('title', 'text', 'title_text') fields = ("title", "text", "title_text")
class SearchFilterAnnotatedFieldTests(TestCase): class SearchFilterAnnotatedFieldTests(TestCase):
@classmethod @classmethod
def setUpTestData(cls): def setUpTestData(cls):
SearchFilterModel.objects.create(title='abc', text='def') SearchFilterModel.objects.create(title="abc", text="def")
SearchFilterModel.objects.create(title='ghi', text='jkl') SearchFilterModel.objects.create(title="ghi", text="jkl")
def test_search_in_annotated_field(self): def test_search_in_annotated_field(self):
class SearchListView(generics.ListAPIView): class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate( queryset = SearchFilterModel.objects.annotate(
title_text=Upper( title_text=Upper(Concat(models.F("title"), models.F("text")))
Concat(models.F('title'), models.F('text'))
)
).all() ).all()
serializer_class = SearchFilterAnnotatedSerializer serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,) filter_backends = (filters.SearchFilter,)
search_fields = ('title_text',) search_fields = ("title_text",)
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('/', {'search': 'ABCDEF'}) request = factory.get("/", {"search": "ABCDEF"})
response = view(request) response = view(request)
assert len(response.data) == 1 assert len(response.data) == 1
assert response.data[0]['title_text'] == 'ABCDEF' assert response.data[0]["title_text"] == "ABCDEF"
class OrderingFilterModel(models.Model): class OrderingFilterModel(models.Model):
title = models.CharField(max_length=20, verbose_name='verbose title') title = models.CharField(max_length=20, verbose_name="verbose title")
text = models.CharField(max_length=100) text = models.CharField(max_length=100)
class OrderingFilterRelatedModel(models.Model): class OrderingFilterRelatedModel(models.Model):
related_object = models.ForeignKey(OrderingFilterModel, related_name="relateds", on_delete=models.CASCADE) related_object = models.ForeignKey(
index = models.SmallIntegerField(help_text="A non-related field to test with", default=0) OrderingFilterModel, related_name="relateds", on_delete=models.CASCADE
)
index = models.SmallIntegerField(
help_text="A non-related field to test with", default=0
)
class OrderingFilterSerializer(serializers.ModelSerializer): class OrderingFilterSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = OrderingFilterModel model = OrderingFilterModel
fields = '__all__' fields = "__all__"
class OrderingDottedRelatedSerializer(serializers.ModelSerializer): class OrderingDottedRelatedSerializer(serializers.ModelSerializer):
related_text = serializers.CharField(source='related_object.text') related_text = serializers.CharField(source="related_object.text")
related_title = serializers.CharField(source='related_object.title') related_title = serializers.CharField(source="related_object.title")
class Meta: class Meta:
model = OrderingFilterRelatedModel model = OrderingFilterRelatedModel
fields = ( fields = ("related_text", "related_title", "index")
'related_text',
'related_title',
'index',
)
class DjangoFilterOrderingModel(models.Model): class DjangoFilterOrderingModel(models.Model):
@ -396,13 +390,13 @@ class DjangoFilterOrderingModel(models.Model):
text = models.CharField(max_length=10) text = models.CharField(max_length=10)
class Meta: class Meta:
ordering = ['-date'] ordering = ["-date"]
class DjangoFilterOrderingSerializer(serializers.ModelSerializer): class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = DjangoFilterOrderingModel model = DjangoFilterOrderingModel
fields = '__all__' fields = "__all__"
class OrderingFilterTests(TestCase): class OrderingFilterTests(TestCase):
@ -413,16 +407,8 @@ class OrderingFilterTests(TestCase):
# yxw bcd # yxw bcd
# xwv cde # xwv cde
for idx in range(3): for idx in range(3):
title = ( title = chr(ord("z") - idx) + chr(ord("y") - idx) + chr(ord("x") - idx)
chr(ord('z') - idx) + text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
chr(ord('y') - idx) +
chr(ord('x') - idx)
)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
OrderingFilterModel(title=title, text=text).save() OrderingFilterModel(title=title, text=text).save()
def test_ordering(self): def test_ordering(self):
@ -430,16 +416,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
ordering_fields = ('text',) ordering_fields = ("text",)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request = factory.get("/", {"ordering": "text"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
] ]
def test_reverse_ordering(self): def test_reverse_ordering(self):
@ -447,16 +433,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
ordering_fields = ('text',) ordering_fields = ("text",)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-text'}) request = factory.get("/", {"ordering": "-text"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
] ]
def test_incorrecturl_extrahyphens_ordering(self): def test_incorrecturl_extrahyphens_ordering(self):
@ -464,16 +450,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
ordering_fields = ('text',) ordering_fields = ("text",)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '--text'}) request = factory.get("/", {"ordering": "--text"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
] ]
def test_incorrectfield_ordering(self): def test_incorrectfield_ordering(self):
@ -481,16 +467,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
ordering_fields = ('text',) ordering_fields = ("text",)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'foobar'}) request = factory.get("/", {"ordering": "foobar"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
] ]
def test_default_ordering(self): def test_default_ordering(self):
@ -498,16 +484,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
ordering_fields = ('text',) ordering_fields = ("text",)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('') request = factory.get("")
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
] ]
def test_default_ordering_using_string(self): def test_default_ordering_using_string(self):
@ -515,53 +501,48 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = "title"
ordering_fields = ('text',) ordering_fields = ("text",)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('') request = factory.get("")
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
] ]
def test_ordering_by_aggregate_field(self): def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by # create some related models to aggregate order by
num_objs = [2, 5, 3] num_objs = [2, 5, 3]
for obj, num_relateds in zip(OrderingFilterModel.objects.all(), for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
num_objs):
for _ in range(num_relateds): for _ in range(num_relateds):
new_related = OrderingFilterRelatedModel( new_related = OrderingFilterRelatedModel(related_object=obj)
related_object=obj
)
new_related.save() new_related.save()
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = 'title' ordering = "title"
ordering_fields = '__all__' ordering_fields = "__all__"
queryset = OrderingFilterModel.objects.all().annotate( queryset = OrderingFilterModel.objects.all().annotate(
models.Count("relateds")) models.Count("relateds")
)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'relateds__count'}) request = factory.get("/", {"ordering": "relateds__count"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
] ]
def test_ordering_by_dotted_source(self): def test_ordering_by_dotted_source(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()): for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create( OrderingFilterRelatedModel.objects.create(related_object=obj, index=index)
related_object=obj,
index=index
)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
serializer_class = OrderingDottedRelatedSerializer serializer_class = OrderingDottedRelatedSerializer
@ -569,62 +550,62 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterRelatedModel.objects.all() queryset = OrderingFilterRelatedModel.objects.all()
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'related_object__text'}) request = factory.get("/", {"ordering": "related_object__text"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0}, {"related_title": "zyx", "related_text": "abc", "index": 0},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1}, {"related_title": "yxw", "related_text": "bcd", "index": 1},
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2}, {"related_title": "xwv", "related_text": "cde", "index": 2},
] ]
request = factory.get('/', {'ordering': '-index'}) request = factory.get("/", {"ordering": "-index"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2}, {"related_title": "xwv", "related_text": "cde", "index": 2},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1}, {"related_title": "yxw", "related_text": "bcd", "index": 1},
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0}, {"related_title": "zyx", "related_text": "abc", "index": 0},
] ]
def test_ordering_with_nonstandard_ordering_param(self): def test_ordering_with_nonstandard_ordering_param(self):
with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}): with override_settings(REST_FRAMEWORK={"ORDERING_PARAM": "order"}):
reload_module(filters) reload_module(filters)
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
ordering_fields = ('text',) ordering_fields = ("text",)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'order': 'text'}) request = factory.get("/", {"order": "text"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
] ]
reload_module(filters) reload_module(filters)
def test_get_template_context(self): def test_get_template_context(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
ordering_fields = '__all__' ordering_fields = "__all__"
serializer_class = OrderingFilterSerializer serializer_class = OrderingFilterSerializer
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html') request = factory.get("/", {"ordering": "title"}, HTTP_ACCEPT="text/html")
view = OrderingListView.as_view() view = OrderingListView.as_view()
response = view(request) response = view(request)
self.assertContains(response, 'verbose title') self.assertContains(response, "verbose title")
def test_ordering_with_overridden_get_serializer_class(self): def test_ordering_with_overridden_get_serializer_class(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
# note: no ordering_fields and serializer_class specified # note: no ordering_fields and serializer_class specified
@ -632,24 +613,24 @@ class OrderingFilterTests(TestCase):
return OrderingFilterSerializer return OrderingFilterSerializer
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request = factory.get("/", {"ordering": "text"})
response = view(request) response = view(request)
assert response.data == [ assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'}, {"id": 1, "title": "zyx", "text": "abc"},
{'id': 2, 'title': 'yxw', 'text': 'bcd'}, {"id": 2, "title": "yxw", "text": "bcd"},
{'id': 3, 'title': 'xwv', 'text': 'cde'}, {"id": 3, "title": "xwv", "text": "cde"},
] ]
def test_ordering_with_improper_configuration(self): def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering = ('title',) ordering = ("title",)
# note: no ordering_fields and serializer_class # note: no ordering_fields and serializer_class
# or get_serializer_class specified # or get_serializer_class specified
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'}) request = factory.get("/", {"ordering": "text"})
with self.assertRaises(ImproperlyConfigured): with self.assertRaises(ImproperlyConfigured):
view(request) view(request)
@ -666,7 +647,7 @@ class SensitiveDataSerializer1(serializers.ModelSerializer):
class Meta: class Meta:
model = SensitiveOrderingFilterModel model = SensitiveOrderingFilterModel
fields = ('id', 'username') fields = ("id", "username")
class SensitiveDataSerializer2(serializers.ModelSerializer): class SensitiveDataSerializer2(serializers.ModelSerializer):
@ -675,74 +656,80 @@ class SensitiveDataSerializer2(serializers.ModelSerializer):
class Meta: class Meta:
model = SensitiveOrderingFilterModel model = SensitiveOrderingFilterModel
fields = ('id', 'username', 'password') fields = ("id", "username", "password")
class SensitiveDataSerializer3(serializers.ModelSerializer): class SensitiveDataSerializer3(serializers.ModelSerializer):
user = serializers.CharField(source='username') user = serializers.CharField(source="username")
class Meta: class Meta:
model = SensitiveOrderingFilterModel model = SensitiveOrderingFilterModel
fields = ('id', 'user') fields = ("id", "user")
class SensitiveOrderingFilterTests(TestCase): class SensitiveOrderingFilterTests(TestCase):
def setUp(self): def setUp(self):
for idx in range(3): for idx in range(3):
username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx] username = {0: "userA", 1: "userB", 2: "userC"}[idx]
password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx] password = {0: "passA", 1: "passC", 2: "passB"}[idx]
SensitiveOrderingFilterModel(username=username, password=password).save() SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self): def test_order_by_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [
SensitiveDataSerializer1, SensitiveDataSerializer1,
SensitiveDataSerializer2, SensitiveDataSerializer2,
SensitiveDataSerializer3 SensitiveDataSerializer3,
]: ]:
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') queryset = SensitiveOrderingFilterModel.objects.all().order_by(
"username"
)
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-username'}) request = factory.get("/", {"ordering": "-username"})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:
username_field = 'user' username_field = "user"
else: else:
username_field = 'username' username_field = "username"
# Note: Inverse username ordering correctly applied. # Note: Inverse username ordering correctly applied.
assert response.data == [ assert response.data == [
{'id': 3, username_field: 'userC'}, {"id": 3, username_field: "userC"},
{'id': 2, username_field: 'userB'}, {"id": 2, username_field: "userB"},
{'id': 1, username_field: 'userA'}, {"id": 1, username_field: "userA"},
] ]
def test_cannot_order_by_non_serializer_fields(self): def test_cannot_order_by_non_serializer_fields(self):
for serializer_cls in [ for serializer_cls in [
SensitiveDataSerializer1, SensitiveDataSerializer1,
SensitiveDataSerializer2, SensitiveDataSerializer2,
SensitiveDataSerializer3 SensitiveDataSerializer3,
]: ]:
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username') queryset = SensitiveOrderingFilterModel.objects.all().order_by(
"username"
)
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'password'}) request = factory.get("/", {"ordering": "password"})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:
username_field = 'user' username_field = "user"
else: else:
username_field = 'username' username_field = "username"
# Note: The passwords are not in order. Default ordering is used. # Note: The passwords are not in order. Default ordering is used.
assert response.data == [ assert response.data == [
{'id': 1, username_field: 'userA'}, # PassB {"id": 1, username_field: "userA"}, # PassB
{'id': 2, username_field: 'userB'}, # PassC {"id": 2, username_field: "userB"}, # PassC
{'id': 3, username_field: 'userC'}, # PassA {"id": 3, username_field: "userC"}, # PassA
] ]

View File

@ -17,20 +17,18 @@ class FooView(APIView):
pass pass
urlpatterns = [ urlpatterns = [url(r"^$", FooView.as_view())]
url(r'^$', FooView.as_view())
]
@override_settings(ROOT_URLCONF='tests.test_generateschema') @override_settings(ROOT_URLCONF="tests.test_generateschema")
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed') @pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
class GenerateSchemaTests(TestCase): class GenerateSchemaTests(TestCase):
"""Tests for management command generateschema.""" """Tests for management command generateschema."""
def setUp(self): def setUp(self):
self.out = six.StringIO() self.out = six.StringIO()
@pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.') @pytest.mark.skipif(six.PY2, reason="PyYAML unicode output is malformed on PY2.")
def test_renders_default_schema_with_custom_title_url_and_description(self): def test_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info: expected_out = """info:
description: Sample description description: Sample description
@ -44,45 +42,29 @@ class GenerateSchemaTests(TestCase):
servers: servers:
- url: http://api.sample.com/ - url: http://api.sample.com/
""" """
call_command('generateschema', call_command(
'--title=SampleAPI', "generateschema",
'--url=http://api.sample.com', "--title=SampleAPI",
'--description=Sample description', "--url=http://api.sample.com",
stdout=self.out) "--description=Sample description",
stdout=self.out,
)
self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self): def test_renders_openapi_json_schema(self):
expected_out = { expected_out = {
"openapi": "3.0.0", "openapi": "3.0.0",
"info": { "info": {"version": "", "title": "", "description": ""},
"version": "", "servers": [{"url": ""}],
"title": "", "paths": {"/": {"get": {"operationId": "list"}}},
"description": ""
},
"servers": [
{
"url": ""
}
],
"paths": {
"/": {
"get": {
"operationId": "list"
}
}
}
} }
call_command('generateschema', call_command("generateschema", "--format=openapi-json", stdout=self.out)
'--format=openapi-json',
stdout=self.out)
out_json = json.loads(self.out.getvalue()) out_json = json.loads(self.out.getvalue())
self.assertDictEqual(out_json, expected_out) self.assertDictEqual(out_json, expected_out)
def test_renders_corejson_schema(self): def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}""" expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema', call_command("generateschema", "--format=corejson", stdout=self.out)
'--format=corejson',
stdout=self.out)
self.assertIn(expected_out, self.out.getvalue()) self.assertIn(expected_out, self.out.getvalue())

View File

@ -11,10 +11,14 @@ from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from tests.models import ( from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel, BasicModel,
UUIDForeignKeyTarget ForeignKeySource,
ForeignKeyTarget,
RESTFrameworkModel,
UUIDForeignKeyTarget,
) )
factory = APIRequestFactory() factory = APIRequestFactory()
@ -35,13 +39,13 @@ class Comment(RESTFrameworkModel):
class BasicSerializer(serializers.ModelSerializer): class BasicSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = BasicModel model = BasicModel
fields = '__all__' fields = "__all__"
class ForeignKeySerializer(serializers.ModelSerializer): class ForeignKeySerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ForeignKeySource model = ForeignKeySource
fields = '__all__' fields = "__all__"
class SlugSerializer(serializers.ModelSerializer): class SlugSerializer(serializers.ModelSerializer):
@ -49,7 +53,7 @@ class SlugSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = SlugBasedModel model = SlugBasedModel
fields = ('text', 'slug') fields = ("text", "slug")
# Views # Views
@ -59,7 +63,7 @@ class RootView(generics.ListCreateAPIView):
class InstanceView(generics.RetrieveUpdateDestroyAPIView): class InstanceView(generics.RetrieveUpdateDestroyAPIView):
queryset = BasicModel.objects.exclude(text='filtered out') queryset = BasicModel.objects.exclude(text="filtered out")
serializer_class = BasicSerializer serializer_class = BasicSerializer
@ -72,9 +76,10 @@ class SlugBasedInstanceView(InstanceView):
""" """
A model with a slug-field. A model with a slug-field.
""" """
queryset = SlugBasedModel.objects.all() queryset = SlugBasedModel.objects.all()
serializer_class = SlugSerializer serializer_class = SlugSerializer
lookup_field = 'slug' lookup_field = "slug"
# Tests # Tests
@ -83,21 +88,18 @@ class TestRootView(TestCase):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz'] items = ["foo", "bar", "baz"]
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = RootView.as_view() self.view = RootView.as_view()
def test_get_root_view(self): def test_get_root_view(self):
""" """
GET requests to ListCreateAPIView should return list of objects. GET requests to ListCreateAPIView should return list of objects.
""" """
request = factory.get('/') request = factory.get("/")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -107,7 +109,7 @@ class TestRootView(TestCase):
""" """
HEAD requests to ListCreateAPIView should return 200. HEAD requests to ListCreateAPIView should return 200.
""" """
request = factory.head('/') request = factory.head("/")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -116,21 +118,21 @@ class TestRootView(TestCase):
""" """
POST requests to ListCreateAPIView should create a new object. POST requests to ListCreateAPIView should create a new object.
""" """
data = {'text': 'foobar'} data = {"text": "foobar"}
request = factory.post('/', data, format='json') request = factory.post("/", data, format="json")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assert response.data == {"id": 4, "text": "foobar"}
created = self.objects.get(id=4) created = self.objects.get(id=4)
assert created.text == 'foobar' assert created.text == "foobar"
def test_put_root_view(self): def test_put_root_view(self):
""" """
PUT requests to ListCreateAPIView should not be allowed PUT requests to ListCreateAPIView should not be allowed
""" """
data = {'text': 'foobar'} data = {"text": "foobar"}
request = factory.put('/', data, format='json') request = factory.put("/", data, format="json")
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@ -140,7 +142,7 @@ class TestRootView(TestCase):
""" """
DELETE requests to ListCreateAPIView should not be allowed DELETE requests to ListCreateAPIView should not be allowed
""" """
request = factory.delete('/') request = factory.delete("/")
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@ -150,24 +152,24 @@ class TestRootView(TestCase):
""" """
POST requests to create a new object should not be able to set the id. POST requests to create a new object should not be able to set the id.
""" """
data = {'id': 999, 'text': 'foobar'} data = {"id": 999, "text": "foobar"}
request = factory.post('/', data, format='json') request = factory.post("/", data, format="json")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assert response.data == {"id": 4, "text": "foobar"}
created = self.objects.get(id=4) created = self.objects.get(id=4)
assert created.text == 'foobar' assert created.text == "foobar"
def test_post_error_root_view(self): def test_post_error_root_view(self):
""" """
POST requests to ListCreateAPIView in HTML should include a form error. POST requests to ListCreateAPIView in HTML should include a form error.
""" """
data = {'text': 'foobar' * 100} data = {"text": "foobar" * 100}
request = factory.post('/', data, HTTP_ACCEPT='text/html') request = factory.post("/", data, HTTP_ACCEPT="text/html")
response = self.view(request).render() response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8') assert expected_error in response.rendered_content.decode("utf-8")
EXPECTED_QUERIES_FOR_PUT = 2 EXPECTED_QUERIES_FOR_PUT = 2
@ -178,14 +180,11 @@ class TestInstanceView(TestCase):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz', 'filtered out'] items = ["foo", "bar", "baz", "filtered out"]
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out') self.objects = BasicModel.objects.exclude(text="filtered out")
self.data = [ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = InstanceView.as_view() self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view() self.slug_based_view = SlugBasedInstanceView.as_view()
@ -193,7 +192,7 @@ class TestInstanceView(TestCase):
""" """
GET requests to RetrieveUpdateDestroyAPIView should return a single object. GET requests to RetrieveUpdateDestroyAPIView should return a single object.
""" """
request = factory.get('/1') request = factory.get("/1")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -203,8 +202,8 @@ class TestInstanceView(TestCase):
""" """
POST requests to RetrieveUpdateDestroyAPIView should not be allowed POST requests to RetrieveUpdateDestroyAPIView should not be allowed
""" """
data = {'text': 'foobar'} data = {"text": "foobar"}
request = factory.post('/', data, format='json') request = factory.post("/", data, format="json")
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@ -214,38 +213,38 @@ class TestInstanceView(TestCase):
""" """
PUT requests to RetrieveUpdateDestroyAPIView should update an object. PUT requests to RetrieveUpdateDestroyAPIView should update an object.
""" """
data = {'text': 'foobar'} data = {"text": "foobar"}
request = factory.put('/1', data, format='json') request = factory.put("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render() response = self.view(request, pk="1").render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'} assert dict(response.data) == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
assert updated.text == 'foobar' assert updated.text == "foobar"
def test_patch_instance_view(self): def test_patch_instance_view(self):
""" """
PATCH requests to RetrieveUpdateDestroyAPIView should update an object. PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
""" """
data = {'text': 'foobar'} data = {"text": "foobar"}
request = factory.patch('/1', data, format='json') request = factory.patch("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'} assert response.data == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
assert updated.text == 'foobar' assert updated.text == "foobar"
def test_delete_instance_view(self): def test_delete_instance_view(self):
""" """
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object. DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
""" """
request = factory.delete('/1') request = factory.delete("/1")
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == six.b('') assert response.content == six.b("")
ids = [obj.id for obj in self.objects.all()] ids = [obj.id for obj in self.objects.all()]
assert ids == [2, 3] assert ids == [2, 3]
@ -254,23 +253,23 @@ class TestInstanceView(TestCase):
GET requests with an incorrect pk type, should raise 404, not 500. GET requests with an incorrect pk type, should raise 404, not 500.
Regression test for #890. Regression test for #890.
""" """
request = factory.get('/a') request = factory.get("/a")
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request, pk='a').render() response = self.view(request, pk="a").render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self): def test_put_cannot_set_id(self):
""" """
PUT requests to create a new object should not be able to set the id. PUT requests to create a new object should not be able to set the id.
""" """
data = {'id': 999, 'text': 'foobar'} data = {"id": 999, "text": "foobar"}
request = factory.put('/1', data, format='json') request = factory.put("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'} assert response.data == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
assert updated.text == 'foobar' assert updated.text == "foobar"
def test_put_to_deleted_instance(self): def test_put_to_deleted_instance(self):
""" """
@ -278,8 +277,8 @@ class TestInstanceView(TestCase):
an object does not currently exist. an object does not currently exist.
""" """
self.objects.get(id=1).delete() self.objects.get(id=1).delete()
data = {'text': 'foobar'} data = {"text": "foobar"}
request = factory.put('/1', data, format='json') request = factory.put("/1", data, format="json")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@ -289,9 +288,9 @@ class TestInstanceView(TestCase):
PUT requests to an URL of instance which is filtered out should not be PUT requests to an URL of instance which is filtered out should not be
able to create new objects. able to create new objects.
""" """
data = {'text': 'foo'} data = {"text": "foo"}
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk filtered_out_pk = BasicModel.objects.filter(text="filtered out")[0].pk
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') request = factory.put("/{0}".format(filtered_out_pk), data, format="json")
response = self.view(request, pk=filtered_out_pk).render() response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@ -299,8 +298,8 @@ class TestInstanceView(TestCase):
""" """
PATCH requests should not be able to create objects. PATCH requests should not be able to create objects.
""" """
data = {'text': 'foobar'} data = {"text": "foobar"}
request = factory.patch('/999', data, format='json') request = factory.patch("/999", data, format="json")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=999).render() response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
@ -310,11 +309,11 @@ class TestInstanceView(TestCase):
""" """
Incorrect PUT requests in HTML should include a form error. Incorrect PUT requests in HTML should include a form error.
""" """
data = {'text': 'foobar' * 100} data = {"text": "foobar" * 100}
request = factory.put('/', data, HTTP_ACCEPT='text/html') request = factory.put("/", data, HTTP_ACCEPT="text/html")
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8') assert expected_error in response.rendered_content.decode("utf-8")
class TestFKInstanceView(TestCase): class TestFKInstanceView(TestCase):
@ -322,17 +321,14 @@ class TestFKInstanceView(TestCase):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz'] items = ["foo", "bar", "baz"]
for item in items: for item in items:
t = ForeignKeyTarget(name=item) t = ForeignKeyTarget(name=item)
t.save() t.save()
ForeignKeySource(name='source_' + item, target=t).save() ForeignKeySource(name="source_" + item, target=t).save()
self.objects = ForeignKeySource.objects self.objects = ForeignKeySource.objects
self.data = [ self.data = [{"id": obj.id, "name": obj.name} for obj in self.objects.all()]
{'id': obj.id, 'name': obj.name}
for obj in self.objects.all()
]
self.view = FKInstanceView.as_view() self.view = FKInstanceView.as_view()
@ -346,23 +342,21 @@ class TestOverriddenGetObject(TestCase):
""" """
Create 3 BasicModel instances. Create 3 BasicModel instances.
""" """
items = ['foo', 'bar', 'baz'] items = ["foo", "bar", "baz"]
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
""" """
Example detail view for override of get_object(). Example detail view for override of get_object().
""" """
serializer_class = BasicSerializer serializer_class = BasicSerializer
def get_object(self): def get_object(self):
pk = int(self.kwargs['pk']) pk = int(self.kwargs["pk"])
return get_object_or_404(BasicModel.objects.all(), id=pk) return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view() self.view = OverriddenGetObjectView.as_view()
@ -371,7 +365,7 @@ class TestOverriddenGetObject(TestCase):
""" """
GET requests to RetrieveUpdateDestroyAPIView should return a single object. GET requests to RetrieveUpdateDestroyAPIView should return a single object.
""" """
request = factory.get('/1') request = factory.get("/1")
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -380,10 +374,11 @@ class TestOverriddenGetObject(TestCase):
# Regression test for #285 # Regression test for #285
class CommentSerializer(serializers.ModelSerializer): class CommentSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Comment model = Comment
exclude = ('created',) exclude = ("created",)
class CommentView(generics.ListCreateAPIView): class CommentView(generics.ListCreateAPIView):
@ -402,12 +397,12 @@ class TestCreateModelWithAutoNowAddField(TestCase):
https://github.com/encode/django-rest-framework/issues/285 https://github.com/encode/django-rest-framework/issues/285
""" """
data = {'email': 'foobar@example.com', 'content': 'foobar'} data = {"email": "foobar@example.com", "content": "foobar"}
request = factory.post('/', data, format='json') request = factory.post("/", data, format="json")
response = self.view(request).render() response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1) created = self.objects.get(id=1)
assert created.content == 'foobar' assert created.content == "foobar"
# Test for particularly ugly regression with m2m in browsable API # Test for particularly ugly regression with m2m in browsable API
@ -427,7 +422,7 @@ class ClassASerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ClassA model = ClassA
fields = '__all__' fields = "__all__"
class ExampleView(generics.ListCreateAPIView): class ExampleView(generics.ListCreateAPIView):
@ -440,7 +435,7 @@ class TestM2MBrowsableAPI(TestCase):
""" """
Test for particularly ugly regression with m2m in browsable API Test for particularly ugly regression with m2m in browsable API
""" """
request = factory.get('/', HTTP_ACCEPT='text/html') request = factory.get("/", HTTP_ACCEPT="text/html")
view = ExampleView().as_view() view = ExampleView().as_view()
response = view(request).render() response = view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
@ -448,12 +443,12 @@ class TestM2MBrowsableAPI(TestCase):
class InclusiveFilterBackend(object): class InclusiveFilterBackend(object):
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
return queryset.filter(text='foo') return queryset.filter(text="foo")
class ExclusiveFilterBackend(object): class ExclusiveFilterBackend(object):
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
return queryset.filter(text='other') return queryset.filter(text="other")
class TwoFieldModel(models.Model): class TwoFieldModel(models.Model):
@ -466,16 +461,20 @@ class DynamicSerializerView(generics.ListCreateAPIView):
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer) renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get_serializer_class(self): def get_serializer_class(self):
if self.request.method == 'POST': if self.request.method == "POST":
class DynamicSerializer(serializers.ModelSerializer): class DynamicSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = TwoFieldModel model = TwoFieldModel
fields = ('field_b',) fields = ("field_b",)
else: else:
class DynamicSerializer(serializers.ModelSerializer): class DynamicSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = TwoFieldModel model = TwoFieldModel
fields = '__all__' fields = "__all__"
return DynamicSerializer return DynamicSerializer
@ -484,32 +483,29 @@ class TestFilterBackendAppliedToViews(TestCase):
""" """
Create 3 BasicModel instances to filter on. Create 3 BasicModel instances to filter on.
""" """
items = ['foo', 'bar', 'baz'] items = ["foo", "bar", "baz"]
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
def test_get_root_view_filters_by_name_with_filter_backend(self): def test_get_root_view_filters_by_name_with_filter_backend(self):
""" """
GET requests to ListCreateAPIView should return filtered list. GET requests to ListCreateAPIView should return filtered list.
""" """
root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,)) root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/') request = factory.get("/")
response = root_view(request).render() response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1 assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}] assert response.data == [{"id": 1, "text": "foo"}]
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
""" """
GET requests to ListCreateAPIView should return empty list when all models are filtered out. GET requests to ListCreateAPIView should return empty list when all models are filtered out.
""" """
root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,)) root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/') request = factory.get("/")
response = root_view(request).render() response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == [] assert response.data == []
@ -519,31 +515,33 @@ class TestFilterBackendAppliedToViews(TestCase):
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
""" """
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1') request = factory.get("/1")
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'} assert response.data == {"detail": "Not found."}
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(
self
):
""" """
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
""" """
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1') request = factory.get("/1")
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'} assert response.data == {"id": 1, "text": "foo"}
def test_dynamic_serializer_form_in_browsable_api(self): def test_dynamic_serializer_form_in_browsable_api(self):
""" """
GET requests to ListCreateAPIView should return filtered list. GET requests to ListCreateAPIView should return filtered list.
""" """
view = DynamicSerializerView.as_view() view = DynamicSerializerView.as_view()
request = factory.get('/') request = factory.get("/")
response = view(request).render() response = view(request).render()
content = response.content.decode('utf8') content = response.content.decode("utf8")
assert 'field_b' in content assert "field_b" in content
assert 'field_a' not in content assert "field_a" not in content
class TestGuardedQueryset(TestCase): class TestGuardedQueryset(TestCase):
@ -555,21 +553,21 @@ class TestGuardedQueryset(TestCase):
return Response(list(self.queryset)) return Response(list(self.queryset))
view = QuerysetAccessError.as_view() view = QuerysetAccessError.as_view()
request = factory.get('/') request = factory.get("/")
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
view(request).render() view(request).render()
class ApiViewsTests(TestCase): class ApiViewsTests(TestCase):
def test_create_api_view_post(self): def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView): class MockCreateApiView(generics.CreateAPIView):
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockCreateApiView() view = MockCreateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.post('test request', 'test arg', test_kwarg='test') view.post("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -578,9 +576,10 @@ class ApiViewsTests(TestCase):
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockDestroyApiView() view = MockDestroyApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.delete('test request', 'test arg', test_kwarg='test') view.delete("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -589,9 +588,10 @@ class ApiViewsTests(TestCase):
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockUpdateApiView() view = MockUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.patch('test request', 'test arg', test_kwarg='test') view.patch("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -600,9 +600,10 @@ class ApiViewsTests(TestCase):
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.get('test request', 'test arg', test_kwarg='test') view.get("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -611,9 +612,10 @@ class ApiViewsTests(TestCase):
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.put('test request', 'test arg', test_kwarg='test') view.put("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -622,9 +624,10 @@ class ApiViewsTests(TestCase):
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView() view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.patch('test request', 'test arg', test_kwarg='test') view.patch("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -633,9 +636,10 @@ class ApiViewsTests(TestCase):
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView() view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.get('test request', 'test arg', test_kwarg='test') view.get("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -644,9 +648,10 @@ class ApiViewsTests(TestCase):
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
self.called = True self.called = True
self.call_args = (request, args, kwargs) self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView() view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'}) data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.delete('test request', 'test arg', test_kwarg='test') view.delete("test request", "test arg", test_kwarg="test")
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
@ -654,14 +659,12 @@ class ApiViewsTests(TestCase):
class GetObjectOr404Tests(TestCase): class GetObjectOr404Tests(TestCase):
def setUp(self): def setUp(self):
super(GetObjectOr404Tests, self).setUp() super(GetObjectOr404Tests, self).setUp()
self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar') self.uuid_object = UUIDForeignKeyTarget.objects.create(name="bar")
def test_get_object_or_404_with_valid_uuid(self): def test_get_object_or_404_with_valid_uuid(self):
obj = generics.get_object_or_404( obj = generics.get_object_or_404(UUIDForeignKeyTarget, pk=self.uuid_object.pk)
UUIDForeignKeyTarget, pk=self.uuid_object.pk
)
assert obj == self.uuid_object assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self): def test_get_object_or_404_with_invalid_string_for_uuid(self):
with pytest.raises(Http404): with pytest.raises(Http404):
generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid') generics.get_object_or_404(UUIDForeignKeyTarget, pk="not-a-uuid")

View File

@ -15,40 +15,41 @@ from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response from rest_framework.response import Response
@api_view(('GET',)) @api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,)) @renderer_classes((TemplateHTMLRenderer,))
def example(request): def example(request):
""" """
A view that can returns an HTML representation. A view that can returns an HTML representation.
""" """
data = {'object': 'foobar'} data = {"object": "foobar"}
return Response(data, template_name='example.html') return Response(data, template_name="example.html")
@api_view(('GET',)) @api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,)) @renderer_classes((TemplateHTMLRenderer,))
def permission_denied(request): def permission_denied(request):
raise PermissionDenied() raise PermissionDenied()
@api_view(('GET',)) @api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,)) @renderer_classes((TemplateHTMLRenderer,))
def not_found(request): def not_found(request):
raise Http404() raise Http404()
urlpatterns = [ urlpatterns = [
url(r'^$', example), url(r"^$", example),
url(r'^permission_denied$', permission_denied), url(r"^permission_denied$", permission_denied),
url(r'^not_found$', not_found), url(r"^not_found$", not_found),
] ]
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF="tests.test_htmlrenderer")
class TemplateHTMLRendererTests(TestCase): class TemplateHTMLRendererTests(TestCase):
def setUp(self): def setUp(self):
class MockResponse(object): class MockResponse(object):
template_name = None template_name = None
self.mock_response = MockResponse() self.mock_response = MockResponse()
self._monkey_patch_get_template() self._monkey_patch_get_template()
@ -59,13 +60,13 @@ class TemplateHTMLRendererTests(TestCase):
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None): def get_template(template_name, dirs=None):
if template_name == 'example.html': if template_name == "example.html":
return engines['django'].from_string("example: {{ object }}") return engines["django"].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
def select_template(template_name_list, dirs=None, using=None): def select_template(template_name_list, dirs=None, using=None):
if template_name_list == ['example.html']: if template_name_list == ["example.html"]:
return engines['django'].from_string("example: {{ object }}") return engines["django"].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name_list[0]) raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template django.template.loader.get_template = get_template
@ -78,29 +79,29 @@ class TemplateHTMLRendererTests(TestCase):
django.template.loader.get_template = self.get_template django.template.loader.get_template = self.get_template
def test_simple_html_view(self): def test_simple_html_view(self):
response = self.client.get('/') response = self.client.get("/")
self.assertContains(response, "example: foobar") self.assertContains(response, "example: foobar")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_not_found_html_view(self): def test_not_found_html_view(self):
response = self.client.get('/not_found') response = self.client.get("/not_found")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404 Not Found")) self.assertEqual(response.content, six.b("404 Not Found"))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_permission_denied_html_view(self): def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied') response = self.client.get("/permission_denied")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403 Forbidden")) self.assertEqual(response.content, six.b("403 Forbidden"))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
# 2 tests below are based on order of if statements in corresponding method # 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer # of TemplateHTMLRenderer
def test_get_template_names_returns_own_template_name(self): def test_get_template_names_returns_own_template_name(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
renderer.template_name = 'test_template' renderer.template_name = "test_template"
template_name = renderer.get_template_names(self.mock_response, view={}) template_name = renderer.get_template_names(self.mock_response, view={})
assert template_name == ['test_template'] assert template_name == ["test_template"]
def test_get_template_names_returns_view_template_name(self): def test_get_template_names_returns_view_template_name(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
@ -110,18 +111,16 @@ class TemplateHTMLRendererTests(TestCase):
class MockView(object): class MockView(object):
def get_template_names(self): def get_template_names(self):
return ['template from get_template_names method'] return ["template from get_template_names method"]
class MockView2(object): class MockView2(object):
template_name = 'template from template_name attribute' template_name = "template from template_name attribute"
template_name = renderer.get_template_names(self.mock_response, template_name = renderer.get_template_names(self.mock_response, MockView())
MockView()) assert template_name == ["template from get_template_names method"]
assert template_name == ['template from get_template_names method']
template_name = renderer.get_template_names(self.mock_response, template_name = renderer.get_template_names(self.mock_response, MockView2())
MockView2()) assert template_name == ["template from template_name attribute"]
assert template_name == ['template from template_name attribute']
def test_get_template_names_raises_error_if_no_template_found(self): def test_get_template_names_raises_error_if_no_template_found(self):
renderer = TemplateHTMLRenderer() renderer = TemplateHTMLRenderer()
@ -129,7 +128,7 @@ class TemplateHTMLRendererTests(TestCase):
renderer.get_template_names(self.mock_response, view=object()) renderer.get_template_names(self.mock_response, view=object())
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer') @override_settings(ROOT_URLCONF="tests.test_htmlrenderer")
class TemplateHTMLRendererExceptionTests(TestCase): class TemplateHTMLRendererExceptionTests(TestCase):
def setUp(self): def setUp(self):
""" """
@ -138,10 +137,10 @@ class TemplateHTMLRendererExceptionTests(TestCase):
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name): def get_template(template_name):
if template_name == '404.html': if template_name == "404.html":
return engines['django'].from_string("404: {{ detail }}") return engines["django"].from_string("404: {{ detail }}")
if template_name == '403.html': if template_name == "403.html":
return engines['django'].from_string("403: {{ detail }}") return engines["django"].from_string("403: {{ detail }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
django.template.loader.get_template = get_template django.template.loader.get_template = get_template
@ -153,15 +152,18 @@ class TemplateHTMLRendererExceptionTests(TestCase):
django.template.loader.get_template = self.get_template django.template.loader.get_template = self.get_template
def test_not_found_html_view_with_template(self): def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found') response = self.client.get("/not_found")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertTrue(response.content in ( self.assertTrue(
six.b("404: Not found"), six.b("404 Not Found"))) response.content in (six.b("404: Not found"), six.b("404 Not Found"))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') )
self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_permission_denied_html_view_with_template(self): def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied') response = self.client.get("/permission_denied")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertTrue(response.content in ( self.assertTrue(
six.b("403: Permission denied"), six.b("403 Forbidden"))) response.content
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') in (six.b("403: Permission denied"), six.b("403 Forbidden"))
)
self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")

View File

@ -6,6 +6,7 @@ from rest_framework import serializers
from rest_framework.renderers import JSONRenderer from rest_framework.renderers import JSONRenderer
from rest_framework.templatetags.rest_framework import format_value from rest_framework.templatetags.rest_framework import format_value
str_called = False str_called = False
@ -15,35 +16,33 @@ class Example(models.Model):
def __str__(self): def __str__(self):
global str_called global str_called
str_called = True str_called = True
return 'An example' return "An example"
class ExampleSerializer(serializers.HyperlinkedModelSerializer): class ExampleSerializer(serializers.HyperlinkedModelSerializer):
class Meta: class Meta:
model = Example model = Example
fields = ('url', 'id', 'text') fields = ("url", "id", "text")
def dummy_view(request): def dummy_view(request):
pass pass
urlpatterns = [ urlpatterns = [url(r"^example/(?P<pk>[0-9]+)/$", dummy_view, name="example-detail")]
url(r'^example/(?P<pk>[0-9]+)/$', dummy_view, name='example-detail'),
]
@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks') @override_settings(ROOT_URLCONF="tests.test_lazy_hyperlinks")
class TestLazyHyperlinkNames(TestCase): class TestLazyHyperlinkNames(TestCase):
def setUp(self): def setUp(self):
self.example = Example.objects.create(text='foo') self.example = Example.objects.create(text="foo")
def test_lazy_hyperlink_names(self): def test_lazy_hyperlink_names(self):
global str_called global str_called
context = {'request': None} context = {"request": None}
serializer = ExampleSerializer(self.example, context=context) serializer = ExampleSerializer(self.example, context=context)
JSONRenderer().render(serializer.data) JSONRenderer().render(serializer.data)
assert not str_called assert not str_called
hyperlink_string = format_value(serializer.data['url']) hyperlink_string = format_value(serializer.data["url"])
assert hyperlink_string == '<a href=/example/1/>An example</a>' assert hyperlink_string == "<a href=/example/1/>An example</a>"
assert str_called assert str_called

View File

@ -5,19 +5,17 @@ from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import ( from rest_framework import exceptions, metadata, serializers, status, versioning, views
exceptions, metadata, serializers, status, versioning, views
)
from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from .models import BasicModel from .models import BasicModel
request = APIRequestFactory().options('/')
request = APIRequestFactory().options("/")
class TestMetadata: class TestMetadata:
def test_determine_metadata_abstract_method_raises_proper_error(self): def test_determine_metadata_abstract_method_raises_proper_error(self):
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
metadata.BaseMetadata().determine_metadata(None, None) metadata.BaseMetadata().determine_metadata(None, None)
@ -26,24 +24,23 @@ class TestMetadata:
""" """
OPTIONS requests to views should return a valid 200 response. OPTIONS requests to views should return a valid 200 response.
""" """
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
pass pass
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response = view(request=request)
expected = { expected = {
'name': 'Example', "name": "Example",
'description': 'Example view.', "description": "Example view.",
'renders': [ "renders": ["application/json", "text/html"],
'application/json', "parses": [
'text/html' "application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
], ],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
]
} }
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == expected assert response.data == expected
@ -53,41 +50,40 @@ class TestMetadata:
OPTIONS requests to views where `metadata_class = None` should raise OPTIONS requests to views where `metadata_class = None` should raise
a MethodNotAllowed exception, which will result in an HTTP 405 response. a MethodNotAllowed exception, which will result in an HTTP 405 response.
""" """
class ExampleView(views.APIView): class ExampleView(views.APIView):
metadata_class = None metadata_class = None
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response = view(request=request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {'detail': 'Method "OPTIONS" not allowed.'} assert response.data == {"detail": 'Method "OPTIONS" not allowed.'}
def test_actions(self): def test_actions(self):
""" """
On generic views OPTIONS should return an 'actions' key with metadata On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests. on the fields that may be supplied to PUT and POST requests.
""" """
class NestedField(serializers.Serializer): class NestedField(serializers.Serializer):
a = serializers.IntegerField() a = serializers.IntegerField()
b = serializers.IntegerField() b = serializers.IntegerField()
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue']) choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField( integer_field = serializers.IntegerField(min_value=1, max_value=1000)
min_value=1, max_value=1000
)
char_field = serializers.CharField( char_field = serializers.CharField(
required=False, min_length=3, max_length=40 required=False, min_length=3, max_length=40
) )
list_field = serializers.ListField( list_field = serializers.ListField(
child=serializers.ListField( child=serializers.ListField(child=serializers.IntegerField())
child=serializers.IntegerField()
)
) )
nested_field = NestedField() nested_field = NestedField()
uuid_field = serializers.UUIDField(label="UUID field") uuid_field = serializers.UUIDField(label="UUID field")
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -97,91 +93,87 @@ class TestMetadata:
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response = view(request=request)
expected = { expected = {
'name': 'Example', "name": "Example",
'description': 'Example view.', "description": "Example view.",
'renders': [ "renders": ["application/json", "text/html"],
'application/json', "parses": [
'text/html' "application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
], ],
'parses': [ "actions": {
'application/json', "POST": {
'application/x-www-form-urlencoded', "choice_field": {
'multipart/form-data' "type": "choice",
], "required": True,
'actions': { "read_only": False,
'POST': { "label": "Choice field",
'choice_field': { "choices": [
'type': 'choice', {"display_name": "red", "value": "red"},
'required': True, {"display_name": "green", "value": "green"},
'read_only': False, {"display_name": "blue", "value": "blue"},
'label': 'Choice field', ],
'choices': [
{'display_name': 'red', 'value': 'red'},
{'display_name': 'green', 'value': 'green'},
{'display_name': 'blue', 'value': 'blue'}
]
}, },
'integer_field': { "integer_field": {
'type': 'integer', "type": "integer",
'required': True, "required": True,
'read_only': False, "read_only": False,
'label': 'Integer field', "label": "Integer field",
'min_value': 1, "min_value": 1,
'max_value': 1000, "max_value": 1000,
}, },
'char_field': { "char_field": {
'type': 'string', "type": "string",
'required': False, "required": False,
'read_only': False, "read_only": False,
'label': 'Char field', "label": "Char field",
'min_length': 3, "min_length": 3,
'max_length': 40 "max_length": 40,
}, },
'list_field': { "list_field": {
'type': 'list', "type": "list",
'required': True, "required": True,
'read_only': False, "read_only": False,
'label': 'List field', "label": "List field",
'child': { "child": {
'type': 'list', "type": "list",
'required': True, "required": True,
'read_only': False, "read_only": False,
'child': { "child": {
'type': 'integer', "type": "integer",
'required': True, "required": True,
'read_only': False "read_only": False,
}
}
},
'nested_field': {
'type': 'nested object',
'required': True,
'read_only': False,
'label': 'Nested field',
'children': {
'a': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'A'
}, },
'b': { },
'type': 'integer',
'required': True,
'read_only': False,
'label': 'B'
}
}
}, },
'uuid_field': { "nested_field": {
"type": "nested object",
"required": True,
"read_only": False,
"label": "Nested field",
"children": {
"a": {
"type": "integer",
"required": True,
"read_only": False,
"label": "A",
},
"b": {
"type": "integer",
"required": True,
"read_only": False,
"label": "B",
},
},
},
"uuid_field": {
"type": "string", "type": "string",
"required": True, "required": True,
"read_only": False, "read_only": False,
"label": "UUID field", "label": "UUID field",
}, },
} }
} },
} }
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == expected assert response.data == expected
@ -191,13 +183,15 @@ class TestMetadata:
If a user does not have global permissions on an action, then any If a user does not have global permissions on an action, then any
metadata associated with it should not be included in OPTION responses. metadata associated with it should not be included in OPTION responses.
""" """
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue']) choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField(max_value=10) integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False) char_field = serializers.CharField(required=False)
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -208,26 +202,28 @@ class TestMetadata:
return ExampleSerializer() return ExampleSerializer()
def check_permissions(self, request): def check_permissions(self, request):
if request.method == 'POST': if request.method == "POST":
raise exceptions.PermissionDenied() raise exceptions.PermissionDenied()
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response = view(request=request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions']) == ['PUT'] assert list(response.data["actions"]) == ["PUT"]
def test_object_permissions(self): def test_object_permissions(self):
""" """
If a user does not have object permissions on an action, then any If a user does not have object permissions on an action, then any
metadata associated with it should not be included in OPTION responses. metadata associated with it should not be included in OPTION responses.
""" """
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue']) choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField(max_value=10) integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False) char_field = serializers.CharField(required=False)
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -238,13 +234,13 @@ class TestMetadata:
return ExampleSerializer() return ExampleSerializer()
def get_object(self): def get_object(self):
if self.request.method == 'PUT': if self.request.method == "PUT":
raise exceptions.PermissionDenied() raise exceptions.PermissionDenied()
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response = view(request=request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions'].keys()) == ['POST'] assert list(response.data["actions"].keys()) == ["POST"]
def test_bug_2455_clone_request(self): def test_bug_2455_clone_request(self):
class ExampleView(views.APIView): class ExampleView(views.APIView):
@ -254,7 +250,7 @@ class TestMetadata:
pass pass
def get_serializer(self): def get_serializer(self):
assert hasattr(self.request, 'version') assert hasattr(self.request, "version")
return serializers.Serializer() return serializers.Serializer()
view = ExampleView.as_view() view = ExampleView.as_view()
@ -268,7 +264,7 @@ class TestMetadata:
pass pass
def get_serializer(self): def get_serializer(self):
assert hasattr(self.request, 'versioning_scheme') assert hasattr(self.request, "versioning_scheme")
return serializers.Serializer() return serializers.Serializer()
scheme = versioning.QueryParameterVersioning scheme = versioning.QueryParameterVersioning
@ -279,12 +275,14 @@ class TestMetadata:
""" """
HiddenField shouldn't show up in SimpleMetadata at all. HiddenField shouldn't show up in SimpleMetadata at all.
""" """
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
integer_field = serializers.IntegerField(max_value=10) integer_field = serializers.IntegerField(max_value=10)
hidden_field = serializers.HiddenField(default=1) hidden_field = serializers.HiddenField(default=1)
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -294,9 +292,11 @@ class TestMetadata:
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response = view(request=request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert set(response.data['actions']['POST'].keys()) == {'integer_field'} assert set(response.data["actions"]["POST"].keys()) == {"integer_field"}
def test_list_serializer_metadata_returns_info_about_fields_of_child_serializer(self): def test_list_serializer_metadata_returns_info_about_fields_of_child_serializer(
self
):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
integer_field = serializers.IntegerField(max_value=10) integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False) char_field = serializers.CharField(required=False)
@ -307,14 +307,16 @@ class TestMetadata:
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
child_serializer = ExampleSerializer() child_serializer = ExampleSerializer()
list_serializer = ExampleListSerializer(child=child_serializer) list_serializer = ExampleListSerializer(child=child_serializer)
assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer) assert options.get_serializer_info(
list_serializer
) == options.get_serializer_info(child_serializer)
class TestSimpleMetadataFieldInfo(TestCase): class TestSimpleMetadataFieldInfo(TestCase):
def test_null_boolean_field_info_type(self): def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField()) field_info = options.get_field_info(serializers.NullBooleanField())
assert field_info['type'] == 'boolean' assert field_info["type"] == "boolean"
def test_related_field_choices(self): def test_related_field_choices(self):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
@ -323,7 +325,7 @@ class TestSimpleMetadataFieldInfo(TestCase):
field_info = options.get_field_info( field_info = options.get_field_info(
serializers.RelatedField(queryset=BasicModel.objects.all()) serializers.RelatedField(queryset=BasicModel.objects.all())
) )
assert 'choices' not in field_info assert "choices" not in field_info
class TestModelSerializerMetadata(TestCase): class TestModelSerializerMetadata(TestCase):
@ -333,9 +335,12 @@ class TestModelSerializerMetadata(TestCase):
on the fields that may be supplied to PUT and POST requests. It should on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present not fail when a read_only PrimaryKeyRelatedField is present
""" """
class Parent(models.Model): class Parent(models.Model):
integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)]) integer_field = models.IntegerField(
children = models.ManyToManyField('Child') validators=[MinValueValidator(1), MaxValueValidator(1000)]
)
children = models.ManyToManyField("Child")
name = models.CharField(max_length=100, blank=True, null=True) name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model): class Child(models.Model):
@ -346,10 +351,11 @@ class TestModelSerializerMetadata(TestCase):
class Meta: class Meta:
model = Parent model = Parent
fields = '__all__' fields = "__all__"
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -359,48 +365,45 @@ class TestModelSerializerMetadata(TestCase):
view = ExampleView.as_view() view = ExampleView.as_view()
response = view(request=request) response = view(request=request)
expected = { expected = {
'name': 'Example', "name": "Example",
'description': 'Example view.', "description": "Example view.",
'renders': [ "renders": ["application/json", "text/html"],
'application/json', "parses": [
'text/html' "application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
], ],
'parses': [ "actions": {
'application/json', "POST": {
'application/x-www-form-urlencoded', "id": {
'multipart/form-data' "type": "integer",
], "required": False,
'actions': { "read_only": True,
'POST': { "label": "ID",
'id': {
'type': 'integer',
'required': False,
'read_only': True,
'label': 'ID'
}, },
'children': { "children": {
'type': 'field', "type": "field",
'required': False, "required": False,
'read_only': True, "read_only": True,
'label': 'Children' "label": "Children",
}, },
'integer_field': { "integer_field": {
'type': 'integer', "type": "integer",
'required': True, "required": True,
'read_only': False, "read_only": False,
'label': 'Integer field', "label": "Integer field",
'min_value': 1, "min_value": 1,
'max_value': 1000 "max_value": 1000,
},
"name": {
"type": "string",
"required": False,
"read_only": False,
"label": "Name",
"max_length": 100,
}, },
'name': {
'type': 'string',
'required': False,
'read_only': False,
'label': 'Name',
'max_length': 100
}
} }
} },
} }
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK

View File

@ -17,8 +17,8 @@ class PostView(APIView):
urlpatterns = [ urlpatterns = [
url(r'^auth$', APIView.as_view(authentication_classes=(TokenAuthentication,))), url(r"^auth$", APIView.as_view(authentication_classes=(TokenAuthentication,))),
url(r'^post$', PostView.as_view()), url(r"^post$", PostView.as_view()),
] ]
@ -28,8 +28,8 @@ class RequestUserMiddleware(object):
def __call__(self, request): def __call__(self, request):
response = self.get_response(request) response = self.get_response(request)
assert hasattr(request, 'user'), '`user` is not set on request' assert hasattr(request, "user"), "`user` is not set on request"
assert request.user.is_authenticated, '`user` is not authenticated' assert request.user.is_authenticated, "`user` is not authenticated"
return response return response
@ -49,28 +49,27 @@ class RequestPOSTMiddleware(object):
# Ensure request.POST is set as appropriate # Ensure request.POST is set as appropriate
if is_form_media_type(request.content_type): if is_form_media_type(request.content_type):
assert request.POST == {'foo': ['bar']} assert request.POST == {"foo": ["bar"]}
else: else:
assert request.POST == {} assert request.POST == {}
return response return response
@override_settings(ROOT_URLCONF='tests.test_middleware') @override_settings(ROOT_URLCONF="tests.test_middleware")
class TestMiddleware(APITestCase): class TestMiddleware(APITestCase):
@override_settings(MIDDLEWARE=("tests.test_middleware.RequestUserMiddleware",))
@override_settings(MIDDLEWARE=('tests.test_middleware.RequestUserMiddleware',))
def test_middleware_can_access_user_when_processing_response(self): def test_middleware_can_access_user_when_processing_response(self):
user = User.objects.create_user('john', 'john@example.com', 'password') user = User.objects.create_user("john", "john@example.com", "password")
key = 'abcd1234' key = "abcd1234"
Token.objects.create(key=key, user=user) Token.objects.create(key=key, user=user)
self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key) self.client.get("/auth", HTTP_AUTHORIZATION="Token %s" % key)
@override_settings(MIDDLEWARE=('tests.test_middleware.RequestPOSTMiddleware',)) @override_settings(MIDDLEWARE=("tests.test_middleware.RequestPOSTMiddleware",))
def test_middleware_can_access_request_post_when_processing_response(self): def test_middleware_can_access_request_post_when_processing_response(self):
response = self.client.post('/post', {'foo': 'bar'}) response = self.client.post("/post", {"foo": "bar"})
assert response.status_code == 200 assert response.status_code == 200
response = self.client.post('/post', {'foo': 'bar'}, format='json') response = self.client.post("/post", {"foo": "bar"}, format="json")
assert response.status_code == 200 assert response.status_code == 200

File diff suppressed because it is too large Load Diff

View File

@ -25,45 +25,41 @@ class AssociatedModel(RESTFrameworkModel):
class DerivedModelSerializer(serializers.ModelSerializer): class DerivedModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ChildModel model = ChildModel
fields = '__all__' fields = "__all__"
class AssociatedModelSerializer(serializers.ModelSerializer): class AssociatedModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = AssociatedModel model = AssociatedModel
fields = '__all__' fields = "__all__"
# Tests # Tests
class InheritedModelSerializationTests(TestCase): class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields
serialized fields serialized fields
""" """
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1="parent name", name2="child name")
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
assert set(serializer.data) == {'name1', 'name2', 'id'} assert set(serializer.data) == {"name1", "name2", "id"}
def test_onetoone_primary_key_model_fields_as_expected(self): def test_onetoone_primary_key_model_fields_as_expected(self):
""" """
Assert that a model with a onetoone field that is the primary key is Assert that a model with a onetoone field that is the primary key is
not treated like a derived model not treated like a derived model
""" """
parent = ParentModel.objects.create(name1='parent name') parent = ParentModel.objects.create(name1="parent name")
associate = AssociatedModel.objects.create(name='hello', ref=parent) associate = AssociatedModel.objects.create(name="hello", ref=parent)
serializer = AssociatedModelSerializer(associate) serializer = AssociatedModelSerializer(associate)
assert set(serializer.data) == {'name', 'ref'} assert set(serializer.data) == {"name", "ref"}
def test_data_is_valid_without_parent_ptr(self): def test_data_is_valid_without_parent_ptr(self):
""" """
Assert that the pointer to the parent table is not a required field Assert that the pointer to the parent table is not a required field
for input data for input data
""" """
data = { data = {"name1": "parent name", "name2": "child name"}
'name1': 'parent name',
'name2': 'child name',
}
serializer = DerivedModelSerializer(data=data) serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True assert serializer.is_valid() is True

View File

@ -4,32 +4,31 @@ import pytest
from django.http import Http404 from django.http import Http404
from django.test import TestCase from django.test import TestCase
from rest_framework.negotiation import ( from rest_framework.negotiation import BaseContentNegotiation, DefaultContentNegotiation
BaseContentNegotiation, DefaultContentNegotiation
)
from rest_framework.renderers import BaseRenderer from rest_framework.renderers import BaseRenderer
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.utils.mediatypes import _MediaType from rest_framework.utils.mediatypes import _MediaType
factory = APIRequestFactory() factory = APIRequestFactory()
class MockOpenAPIRenderer(BaseRenderer): class MockOpenAPIRenderer(BaseRenderer):
media_type = 'application/openapi+json;version=2.0' media_type = "application/openapi+json;version=2.0"
format = 'swagger' format = "swagger"
class MockJSONRenderer(BaseRenderer): class MockJSONRenderer(BaseRenderer):
media_type = 'application/json' media_type = "application/json"
class MockHTMLRenderer(BaseRenderer): class MockHTMLRenderer(BaseRenderer):
media_type = 'text/html' media_type = "text/html"
class NoCharsetSpecifiedRenderer(BaseRenderer): class NoCharsetSpecifiedRenderer(BaseRenderer):
media_type = 'my/media' media_type = "my/media"
class TestAcceptedMediaType(TestCase): class TestAcceptedMediaType(TestCase):
@ -41,54 +40,56 @@ class TestAcceptedMediaType(TestCase):
return self.negotiator.select_renderer(request, self.renderers) return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self): def test_client_without_accept_use_renderer(self):
request = Request(factory.get('/')) request = Request(factory.get("/"))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json' assert accepted_media_type == "application/json"
def test_client_underspecifies_accept_use_renderer(self): def test_client_underspecifies_accept_use_renderer(self):
request = Request(factory.get('/', HTTP_ACCEPT='*/*')) request = Request(factory.get("/", HTTP_ACCEPT="*/*"))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json' assert accepted_media_type == "application/json"
def test_client_overspecifies_accept_use_client(self): def test_client_overspecifies_accept_use_client(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8')) request = Request(factory.get("/", HTTP_ACCEPT="application/json; indent=8"))
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json; indent=8' assert accepted_media_type == "application/json; indent=8"
def test_client_specifies_parameter(self): def test_client_specifies_parameter(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0')) request = Request(
factory.get("/", HTTP_ACCEPT="application/openapi+json;version=2.0")
)
accepted_renderer, accepted_media_type = self.select_renderer(request) accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/openapi+json;version=2.0' assert accepted_media_type == "application/openapi+json;version=2.0"
assert accepted_renderer.format == 'swagger' assert accepted_renderer.format == "swagger"
def test_match_is_false_if_main_types_not_match(self): def test_match_is_false_if_main_types_not_match(self):
mediatype = _MediaType('test_1') mediatype = _MediaType("test_1")
anoter_mediatype = _MediaType('test_2') anoter_mediatype = _MediaType("test_2")
assert mediatype.match(anoter_mediatype) is False assert mediatype.match(anoter_mediatype) is False
def test_mediatype_match_is_false_if_keys_not_match(self): def test_mediatype_match_is_false_if_keys_not_match(self):
mediatype = _MediaType(';test_param=foo') mediatype = _MediaType(";test_param=foo")
another_mediatype = _MediaType(';test_param=bar') another_mediatype = _MediaType(";test_param=bar")
assert mediatype.match(another_mediatype) is False assert mediatype.match(another_mediatype) is False
def test_mediatype_precedence_with_wildcard_subtype(self): def test_mediatype_precedence_with_wildcard_subtype(self):
mediatype = _MediaType('test/*') mediatype = _MediaType("test/*")
assert mediatype.precedence == 1 assert mediatype.precedence == 1
def test_mediatype_string_representation(self): def test_mediatype_string_representation(self):
mediatype = _MediaType('test/*; foo=bar') mediatype = _MediaType("test/*; foo=bar")
assert str(mediatype) == 'test/*; foo=bar' assert str(mediatype) == "test/*; foo=bar"
def test_raise_error_if_no_suitable_renderers_found(self): def test_raise_error_if_no_suitable_renderers_found(self):
class MockRenderer(object): class MockRenderer(object):
format = 'xml' format = "xml"
renderers = [MockRenderer()] renderers = [MockRenderer()]
with pytest.raises(Http404): with pytest.raises(Http404):
self.negotiator.filter_renderers(renderers, format='json') self.negotiator.filter_renderers(renderers, format="json")
class BaseContentNegotiationTests(TestCase): class BaseContentNegotiationTests(TestCase):
def setUp(self): def setUp(self):
self.negotiator = BaseContentNegotiation() self.negotiator = BaseContentNegotiation()

View File

@ -3,9 +3,9 @@ from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
# Models
from rest_framework import serializers from rest_framework import serializers
from tests.models import RESTFrameworkModel from tests.models import RESTFrameworkModel
# Models
from tests.test_multitable_inheritance import ChildModel from tests.test_multitable_inheritance import ChildModel
@ -19,25 +19,24 @@ class ChildAssociatedModel(RESTFrameworkModel):
class DerivedModelSerializer(serializers.ModelSerializer): class DerivedModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ChildModel model = ChildModel
fields = ['id', 'name1', 'name2', 'childassociatedmodel'] fields = ["id", "name1", "name2", "childassociatedmodel"]
class ChildAssociatedModelSerializer(serializers.ModelSerializer): class ChildAssociatedModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ChildAssociatedModel model = ChildAssociatedModel
fields = ['id', 'child_name'] fields = ["id", "child_name"]
# Tests # Tests
class InheritedModelSerializationTests(TestCase): class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields
serialized fields serialized fields
""" """
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1="parent name", name2="child name")
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
self.assertEqual(set(serializer.data), self.assertEqual(
{'name1', 'name2', 'id', 'childassociatedmodel'}) set(serializer.data), {"name1", "name2", "id", "childassociatedmodel"}
)

View File

@ -8,12 +8,18 @@ from django.test import TestCase
from django.utils import six from django.utils import six
from rest_framework import ( from rest_framework import (
exceptions, filters, generics, pagination, serializers, status exceptions,
filters,
generics,
pagination,
serializers,
status,
) )
from rest_framework.pagination import PAGE_BREAK, PageLink from rest_framework.pagination import PAGE_BREAK, PageLink
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
@ -33,39 +39,39 @@ class TestPaginationIntegration:
class BasicPagination(pagination.PageNumberPagination): class BasicPagination(pagination.PageNumberPagination):
page_size = 5 page_size = 5
page_size_query_param = 'page_size' page_size_query_param = "page_size"
max_page_size = 20 max_page_size = 20
self.view = generics.ListAPIView.as_view( self.view = generics.ListAPIView.as_view(
serializer_class=PassThroughSerializer, serializer_class=PassThroughSerializer,
queryset=range(1, 101), queryset=range(1, 101),
filter_backends=[EvenItemsOnly], filter_backends=[EvenItemsOnly],
pagination_class=BasicPagination pagination_class=BasicPagination,
) )
def test_filtered_items_are_paginated(self): def test_filtered_items_are_paginated(self):
request = factory.get('/', {'page': 2}) request = factory.get("/", {"page": 2})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == { assert response.data == {
'results': [12, 14, 16, 18, 20], "results": [12, 14, 16, 18, 20],
'previous': 'http://testserver/', "previous": "http://testserver/",
'next': 'http://testserver/?page=3', "next": "http://testserver/?page=3",
'count': 50 "count": 50,
} }
def test_setting_page_size(self): def test_setting_page_size(self):
""" """
When 'paginate_by_param' is set, the client may choose a page size. When 'paginate_by_param' is set, the client may choose a page size.
""" """
request = factory.get('/', {'page_size': 10}) request = factory.get("/", {"page_size": 10})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == { assert response.data == {
'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20], "results": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
'previous': None, "previous": None,
'next': 'http://testserver/?page=2&page_size=10', "next": "http://testserver/?page=2&page_size=10",
'count': 50 "count": 50,
} }
def test_setting_page_size_over_maximum(self): def test_setting_page_size_over_maximum(self):
@ -73,70 +79,84 @@ class TestPaginationIntegration:
When page_size parameter exceeds maximum allowable, When page_size parameter exceeds maximum allowable,
then it should be capped to the maximum. then it should be capped to the maximum.
""" """
request = factory.get('/', {'page_size': 1000}) request = factory.get("/", {"page_size": 1000})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == { assert response.data == {
'results': [ "results": [
2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 2,
22, 24, 26, 28, 30, 32, 34, 36, 38, 40 4,
6,
8,
10,
12,
14,
16,
18,
20,
22,
24,
26,
28,
30,
32,
34,
36,
38,
40,
], ],
'previous': None, "previous": None,
'next': 'http://testserver/?page=2&page_size=1000', "next": "http://testserver/?page=2&page_size=1000",
'count': 50 "count": 50,
} }
def test_setting_page_size_to_zero(self): def test_setting_page_size_to_zero(self):
""" """
When page_size parameter is invalid it should return to the default. When page_size parameter is invalid it should return to the default.
""" """
request = factory.get('/', {'page_size': 0}) request = factory.get("/", {"page_size": 0})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == { assert response.data == {
'results': [2, 4, 6, 8, 10], "results": [2, 4, 6, 8, 10],
'previous': None, "previous": None,
'next': 'http://testserver/?page=2&page_size=0', "next": "http://testserver/?page=2&page_size=0",
'count': 50 "count": 50,
} }
def test_additional_query_params_are_preserved(self): def test_additional_query_params_are_preserved(self):
request = factory.get('/', {'page': 2, 'filter': 'even'}) request = factory.get("/", {"page": 2, "filter": "even"})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == { assert response.data == {
'results': [12, 14, 16, 18, 20], "results": [12, 14, 16, 18, 20],
'previous': 'http://testserver/?filter=even', "previous": "http://testserver/?filter=even",
'next': 'http://testserver/?filter=even&page=3', "next": "http://testserver/?filter=even&page=3",
'count': 50 "count": 50,
} }
def test_empty_query_params_are_preserved(self): def test_empty_query_params_are_preserved(self):
request = factory.get('/', {'page': 2, 'filter': ''}) request = factory.get("/", {"page": 2, "filter": ""})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == { assert response.data == {
'results': [12, 14, 16, 18, 20], "results": [12, 14, 16, 18, 20],
'previous': 'http://testserver/?filter=', "previous": "http://testserver/?filter=",
'next': 'http://testserver/?filter=&page=3', "next": "http://testserver/?filter=&page=3",
'count': 50 "count": 50,
} }
def test_404_not_found_for_zero_page(self): def test_404_not_found_for_zero_page(self):
request = factory.get('/', {'page': '0'}) request = factory.get("/", {"page": "0"})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == { assert response.data == {"detail": "Invalid page."}
'detail': 'Invalid page.'
}
def test_404_not_found_for_invalid_page(self): def test_404_not_found_for_invalid_page(self):
request = factory.get('/', {'page': 'invalid'}) request = factory.get("/", {"page": "invalid"})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == { assert response.data == {"detail": "Invalid page."}
'detail': 'Invalid page.'
}
class TestPaginationDisabledIntegration: class TestPaginationDisabledIntegration:
@ -152,11 +172,11 @@ class TestPaginationDisabledIntegration:
self.view = generics.ListAPIView.as_view( self.view = generics.ListAPIView.as_view(
serializer_class=PassThroughSerializer, serializer_class=PassThroughSerializer,
queryset=range(1, 101), queryset=range(1, 101),
pagination_class=None pagination_class=None,
) )
def test_unpaginated_list(self): def test_unpaginated_list(self):
request = factory.get('/', {'page': 2}) request = factory.get("/", {"page": 2})
response = self.view(request) response = self.view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == list(range(1, 101)) assert response.data == list(range(1, 101))
@ -185,81 +205,81 @@ class TestPageNumberPagination:
return self.pagination.get_html_context() return self.pagination.get_html_context()
def test_no_page_number(self): def test_no_page_number(self):
request = Request(factory.get('/')) request = Request(factory.get("/"))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [1, 2, 3, 4, 5] assert queryset == [1, 2, 3, 4, 5]
assert content == { assert content == {
'results': [1, 2, 3, 4, 5], "results": [1, 2, 3, 4, 5],
'previous': None, "previous": None,
'next': 'http://testserver/?page=2', "next": "http://testserver/?page=2",
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': None, "previous_url": None,
'next_url': 'http://testserver/?page=2', "next_url": "http://testserver/?page=2",
'page_links': [ "page_links": [
PageLink('http://testserver/', 1, True, False), PageLink("http://testserver/", 1, True, False),
PageLink('http://testserver/?page=2', 2, False, False), PageLink("http://testserver/?page=2", 2, False, False),
PageLink('http://testserver/?page=3', 3, False, False), PageLink("http://testserver/?page=3", 3, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?page=20', 20, False, False), PageLink("http://testserver/?page=20", 20, False, False),
] ],
} }
assert self.pagination.display_page_controls assert self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type) assert isinstance(self.pagination.to_html(), six.text_type)
def test_second_page(self): def test_second_page(self):
request = Request(factory.get('/', {'page': 2})) request = Request(factory.get("/", {"page": 2}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [6, 7, 8, 9, 10] assert queryset == [6, 7, 8, 9, 10]
assert content == { assert content == {
'results': [6, 7, 8, 9, 10], "results": [6, 7, 8, 9, 10],
'previous': 'http://testserver/', "previous": "http://testserver/",
'next': 'http://testserver/?page=3', "next": "http://testserver/?page=3",
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': 'http://testserver/', "previous_url": "http://testserver/",
'next_url': 'http://testserver/?page=3', "next_url": "http://testserver/?page=3",
'page_links': [ "page_links": [
PageLink('http://testserver/', 1, False, False), PageLink("http://testserver/", 1, False, False),
PageLink('http://testserver/?page=2', 2, True, False), PageLink("http://testserver/?page=2", 2, True, False),
PageLink('http://testserver/?page=3', 3, False, False), PageLink("http://testserver/?page=3", 3, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?page=20', 20, False, False), PageLink("http://testserver/?page=20", 20, False, False),
] ],
} }
def test_last_page(self): def test_last_page(self):
request = Request(factory.get('/', {'page': 'last'})) request = Request(factory.get("/", {"page": "last"}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [96, 97, 98, 99, 100] assert queryset == [96, 97, 98, 99, 100]
assert content == { assert content == {
'results': [96, 97, 98, 99, 100], "results": [96, 97, 98, 99, 100],
'previous': 'http://testserver/?page=19', "previous": "http://testserver/?page=19",
'next': None, "next": None,
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': 'http://testserver/?page=19', "previous_url": "http://testserver/?page=19",
'next_url': None, "next_url": None,
'page_links': [ "page_links": [
PageLink('http://testserver/', 1, False, False), PageLink("http://testserver/", 1, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?page=18', 18, False, False), PageLink("http://testserver/?page=18", 18, False, False),
PageLink('http://testserver/?page=19', 19, False, False), PageLink("http://testserver/?page=19", 19, False, False),
PageLink('http://testserver/?page=20', 20, True, False), PageLink("http://testserver/?page=20", 20, True, False),
] ],
} }
def test_invalid_page(self): def test_invalid_page(self):
request = Request(factory.get('/', {'page': 'invalid'})) request = Request(factory.get("/", {"page": "invalid"}))
with pytest.raises(exceptions.NotFound): with pytest.raises(exceptions.NotFound):
self.paginate_queryset(request) self.paginate_queryset(request)
@ -295,29 +315,22 @@ class TestPageNumberPaginationOverride:
return self.pagination.get_html_context() return self.pagination.get_html_context()
def test_no_page_number(self): def test_no_page_number(self):
request = Request(factory.get('/')) request = Request(factory.get("/"))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [1] assert queryset == [1]
assert content == { assert content == {"results": [1], "previous": None, "next": None, "count": 1}
'results': [1, ],
'previous': None,
'next': None,
'count': 1
}
assert context == { assert context == {
'previous_url': None, "previous_url": None,
'next_url': None, "next_url": None,
'page_links': [ "page_links": [PageLink("http://testserver/", 1, True, False)],
PageLink('http://testserver/', 1, True, False),
]
} }
assert not self.pagination.display_page_controls assert not self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type) assert isinstance(self.pagination.to_html(), six.text_type)
def test_invalid_page(self): def test_invalid_page(self):
request = Request(factory.get('/', {'page': 'invalid'})) request = Request(factory.get("/", {"page": "invalid"}))
with pytest.raises(exceptions.NotFound): with pytest.raises(exceptions.NotFound):
self.paginate_queryset(request) self.paginate_queryset(request)
@ -346,27 +359,27 @@ class TestLimitOffset:
return self.pagination.get_html_context() return self.pagination.get_html_context()
def test_no_offset(self): def test_no_offset(self):
request = Request(factory.get('/', {'limit': 5})) request = Request(factory.get("/", {"limit": 5}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [1, 2, 3, 4, 5] assert queryset == [1, 2, 3, 4, 5]
assert content == { assert content == {
'results': [1, 2, 3, 4, 5], "results": [1, 2, 3, 4, 5],
'previous': None, "previous": None,
'next': 'http://testserver/?limit=5&offset=5', "next": "http://testserver/?limit=5&offset=5",
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': None, "previous_url": None,
'next_url': 'http://testserver/?limit=5&offset=5', "next_url": "http://testserver/?limit=5&offset=5",
'page_links': [ "page_links": [
PageLink('http://testserver/?limit=5', 1, True, False), PageLink("http://testserver/?limit=5", 1, True, False),
PageLink('http://testserver/?limit=5&offset=5', 2, False, False), PageLink("http://testserver/?limit=5&offset=5", 2, False, False),
PageLink('http://testserver/?limit=5&offset=10', 3, False, False), PageLink("http://testserver/?limit=5&offset=10", 3, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=95', 20, False, False), PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
] ],
} }
assert self.pagination.display_page_controls assert self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type) assert isinstance(self.pagination.to_html(), six.text_type)
@ -374,7 +387,8 @@ class TestLimitOffset:
def test_pagination_not_applied_if_limit_or_default_limit_not_set(self): def test_pagination_not_applied_if_limit_or_default_limit_not_set(self):
class MockPagination(pagination.LimitOffsetPagination): class MockPagination(pagination.LimitOffsetPagination):
default_limit = None default_limit = None
request = Request(factory.get('/'))
request = Request(factory.get("/"))
queryset = MockPagination().paginate_queryset(self.queryset, request) queryset = MockPagination().paginate_queryset(self.queryset, request)
assert queryset is None assert queryset is None
@ -384,104 +398,104 @@ class TestLimitOffset:
* The first page should still be offset zero. * The first page should still be offset zero.
* We may end up displaying an extra page in the pagination control. * We may end up displaying an extra page in the pagination control.
""" """
request = Request(factory.get('/', {'limit': 5, 'offset': 1})) request = Request(factory.get("/", {"limit": 5, "offset": 1}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [2, 3, 4, 5, 6] assert queryset == [2, 3, 4, 5, 6]
assert content == { assert content == {
'results': [2, 3, 4, 5, 6], "results": [2, 3, 4, 5, 6],
'previous': 'http://testserver/?limit=5', "previous": "http://testserver/?limit=5",
'next': 'http://testserver/?limit=5&offset=6', "next": "http://testserver/?limit=5&offset=6",
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': 'http://testserver/?limit=5', "previous_url": "http://testserver/?limit=5",
'next_url': 'http://testserver/?limit=5&offset=6', "next_url": "http://testserver/?limit=5&offset=6",
'page_links': [ "page_links": [
PageLink('http://testserver/?limit=5', 1, False, False), PageLink("http://testserver/?limit=5", 1, False, False),
PageLink('http://testserver/?limit=5&offset=1', 2, True, False), PageLink("http://testserver/?limit=5&offset=1", 2, True, False),
PageLink('http://testserver/?limit=5&offset=6', 3, False, False), PageLink("http://testserver/?limit=5&offset=6", 3, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=96', 21, False, False), PageLink("http://testserver/?limit=5&offset=96", 21, False, False),
] ],
} }
def test_first_offset(self): def test_first_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 5})) request = Request(factory.get("/", {"limit": 5, "offset": 5}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [6, 7, 8, 9, 10] assert queryset == [6, 7, 8, 9, 10]
assert content == { assert content == {
'results': [6, 7, 8, 9, 10], "results": [6, 7, 8, 9, 10],
'previous': 'http://testserver/?limit=5', "previous": "http://testserver/?limit=5",
'next': 'http://testserver/?limit=5&offset=10', "next": "http://testserver/?limit=5&offset=10",
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': 'http://testserver/?limit=5', "previous_url": "http://testserver/?limit=5",
'next_url': 'http://testserver/?limit=5&offset=10', "next_url": "http://testserver/?limit=5&offset=10",
'page_links': [ "page_links": [
PageLink('http://testserver/?limit=5', 1, False, False), PageLink("http://testserver/?limit=5", 1, False, False),
PageLink('http://testserver/?limit=5&offset=5', 2, True, False), PageLink("http://testserver/?limit=5&offset=5", 2, True, False),
PageLink('http://testserver/?limit=5&offset=10', 3, False, False), PageLink("http://testserver/?limit=5&offset=10", 3, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=95', 20, False, False), PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
] ],
} }
def test_middle_offset(self): def test_middle_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 10})) request = Request(factory.get("/", {"limit": 5, "offset": 10}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [11, 12, 13, 14, 15] assert queryset == [11, 12, 13, 14, 15]
assert content == { assert content == {
'results': [11, 12, 13, 14, 15], "results": [11, 12, 13, 14, 15],
'previous': 'http://testserver/?limit=5&offset=5', "previous": "http://testserver/?limit=5&offset=5",
'next': 'http://testserver/?limit=5&offset=15', "next": "http://testserver/?limit=5&offset=15",
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': 'http://testserver/?limit=5&offset=5', "previous_url": "http://testserver/?limit=5&offset=5",
'next_url': 'http://testserver/?limit=5&offset=15', "next_url": "http://testserver/?limit=5&offset=15",
'page_links': [ "page_links": [
PageLink('http://testserver/?limit=5', 1, False, False), PageLink("http://testserver/?limit=5", 1, False, False),
PageLink('http://testserver/?limit=5&offset=5', 2, False, False), PageLink("http://testserver/?limit=5&offset=5", 2, False, False),
PageLink('http://testserver/?limit=5&offset=10', 3, True, False), PageLink("http://testserver/?limit=5&offset=10", 3, True, False),
PageLink('http://testserver/?limit=5&offset=15', 4, False, False), PageLink("http://testserver/?limit=5&offset=15", 4, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=95', 20, False, False), PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
] ],
} }
def test_ending_offset(self): def test_ending_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 95})) request = Request(factory.get("/", {"limit": 5, "offset": 95}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
context = self.get_html_context() context = self.get_html_context()
assert queryset == [96, 97, 98, 99, 100] assert queryset == [96, 97, 98, 99, 100]
assert content == { assert content == {
'results': [96, 97, 98, 99, 100], "results": [96, 97, 98, 99, 100],
'previous': 'http://testserver/?limit=5&offset=90', "previous": "http://testserver/?limit=5&offset=90",
'next': None, "next": None,
'count': 100 "count": 100,
} }
assert context == { assert context == {
'previous_url': 'http://testserver/?limit=5&offset=90', "previous_url": "http://testserver/?limit=5&offset=90",
'next_url': None, "next_url": None,
'page_links': [ "page_links": [
PageLink('http://testserver/?limit=5', 1, False, False), PageLink("http://testserver/?limit=5", 1, False, False),
PAGE_BREAK, PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=85', 18, False, False), PageLink("http://testserver/?limit=5&offset=85", 18, False, False),
PageLink('http://testserver/?limit=5&offset=90', 19, False, False), PageLink("http://testserver/?limit=5&offset=90", 19, False, False),
PageLink('http://testserver/?limit=5&offset=95', 20, True, False), PageLink("http://testserver/?limit=5&offset=95", 20, True, False),
] ],
} }
def test_erronous_offset(self): def test_erronous_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 1000})) request = Request(factory.get("/", {"limit": 5, "offset": 1000}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
self.get_paginated_content(queryset) self.get_paginated_content(queryset)
self.get_html_context() self.get_html_context()
@ -490,7 +504,7 @@ class TestLimitOffset:
""" """
An invalid offset query param should be treated as 0. An invalid offset query param should be treated as 0.
""" """
request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'})) request = Request(factory.get("/", {"limit": 5, "offset": "invalid"}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
assert queryset == [1, 2, 3, 4, 5] assert queryset == [1, 2, 3, 4, 5]
@ -498,27 +512,31 @@ class TestLimitOffset:
""" """
An invalid limit query param should be ignored in favor of the default. An invalid limit query param should be ignored in favor of the default.
""" """
request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0})) request = Request(factory.get("/", {"limit": "invalid", "offset": 0}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
next_limit = self.pagination.default_limit next_limit = self.pagination.default_limit
next_offset = self.pagination.default_limit next_offset = self.pagination.default_limit
next_url = 'http://testserver/?limit={0}&offset={1}'.format(next_limit, next_offset) next_url = "http://testserver/?limit={0}&offset={1}".format(
next_limit, next_offset
)
assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
assert content.get('next') == next_url assert content.get("next") == next_url
def test_zero_limit(self): def test_zero_limit(self):
""" """
An zero limit query param should be ignored in favor of the default. An zero limit query param should be ignored in favor of the default.
""" """
request = Request(factory.get('/', {'limit': 0, 'offset': 0})) request = Request(factory.get("/", {"limit": 0, "offset": 0}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
next_limit = self.pagination.default_limit next_limit = self.pagination.default_limit
next_offset = self.pagination.default_limit next_offset = self.pagination.default_limit
next_url = 'http://testserver/?limit={0}&offset={1}'.format(next_limit, next_offset) next_url = "http://testserver/?limit={0}&offset={1}".format(
next_limit, next_offset
)
assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
assert content.get('next') == next_url assert content.get("next") == next_url
def test_max_limit(self): def test_max_limit(self):
""" """
@ -526,47 +544,46 @@ class TestLimitOffset:
requested limit is greater than the max_limit requested limit is greater than the max_limit
""" """
offset = 50 offset = 50
request = Request(factory.get('/', {'limit': '11235', 'offset': offset})) request = Request(factory.get("/", {"limit": "11235", "offset": offset}))
queryset = self.paginate_queryset(request) queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset) content = self.get_paginated_content(queryset)
max_limit = self.pagination.max_limit max_limit = self.pagination.max_limit
next_offset = offset + max_limit next_offset = offset + max_limit
prev_offset = offset - max_limit prev_offset = offset - max_limit
base_url = 'http://testserver/?limit={0}'.format(max_limit) base_url = "http://testserver/?limit={0}".format(max_limit)
next_url = base_url + '&offset={0}'.format(next_offset) next_url = base_url + "&offset={0}".format(next_offset)
prev_url = base_url + '&offset={0}'.format(prev_offset) prev_url = base_url + "&offset={0}".format(prev_offset)
assert queryset == list(range(51, 66)) assert queryset == list(range(51, 66))
assert content.get('next') == next_url assert content.get("next") == next_url
assert content.get('previous') == prev_url assert content.get("previous") == prev_url
class CursorPaginationTestsMixin: class CursorPaginationTestsMixin:
def test_invalid_cursor(self): def test_invalid_cursor(self):
request = Request(factory.get('/', {'cursor': '123'})) request = Request(factory.get("/", {"cursor": "123"}))
with pytest.raises(exceptions.NotFound): with pytest.raises(exceptions.NotFound):
self.pagination.paginate_queryset(self.queryset, request) self.pagination.paginate_queryset(self.queryset, request)
def test_use_with_ordering_filter(self): def test_use_with_ordering_filter(self):
class MockView: class MockView:
filter_backends = (filters.OrderingFilter,) filter_backends = (filters.OrderingFilter,)
ordering_fields = ['username', 'created'] ordering_fields = ["username", "created"]
ordering = 'created' ordering = "created"
request = Request(factory.get('/', {'ordering': 'username'})) request = Request(factory.get("/", {"ordering": "username"}))
ordering = self.pagination.get_ordering(request, [], MockView()) ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('username',) assert ordering == ("username",)
request = Request(factory.get('/', {'ordering': '-username'})) request = Request(factory.get("/", {"ordering": "-username"}))
ordering = self.pagination.get_ordering(request, [], MockView()) ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('-username',) assert ordering == ("-username",)
request = Request(factory.get('/', {'ordering': 'invalid'})) request = Request(factory.get("/", {"ordering": "invalid"}))
ordering = self.pagination.get_ordering(request, [], MockView()) ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',) assert ordering == ("created",)
def test_cursor_pagination(self): def test_cursor_pagination(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/') (previous, current, next, previous_url, next_url) = self.get_pages("/")
assert previous is None assert previous is None
assert current == [1, 1, 1, 1, 1] assert current == [1, 1, 1, 1, 1]
@ -635,7 +652,9 @@ class CursorPaginationTestsMixin:
assert isinstance(self.pagination.to_html(), six.text_type) assert isinstance(self.pagination.to_html(), six.text_type)
def test_cursor_pagination_with_page_size(self): def test_cursor_pagination_with_page_size(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=20') (previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=20"
)
assert previous is None assert previous is None
assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7] assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7]
@ -647,7 +666,9 @@ class CursorPaginationTestsMixin:
assert next is None assert next is None
def test_cursor_pagination_with_page_size_over_limit(self): def test_cursor_pagination_with_page_size_over_limit(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=30') (previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=30"
)
assert previous is None assert previous is None
assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7] assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7]
@ -659,7 +680,9 @@ class CursorPaginationTestsMixin:
assert next is None assert next is None
def test_cursor_pagination_with_page_size_zero(self): def test_cursor_pagination_with_page_size_zero(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=0') (previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=0"
)
assert previous is None assert previous is None
assert current == [1, 1, 1, 1, 1] assert current == [1, 1, 1, 1, 1]
@ -726,7 +749,9 @@ class CursorPaginationTestsMixin:
assert next == [1, 2, 3, 4, 4] assert next == [1, 2, 3, 4, 4]
def test_cursor_pagination_with_page_size_negative(self): def test_cursor_pagination_with_page_size_negative(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=-5') (previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=-5"
)
assert previous is None assert previous is None
assert current == [1, 1, 1, 1, 1] assert current == [1, 1, 1, 1, 1]
@ -809,19 +834,17 @@ class TestCursorPagination(CursorPaginationTestsMixin):
def filter(self, created__gt=None, created__lt=None): def filter(self, created__gt=None, created__lt=None):
if created__gt is not None: if created__gt is not None:
return MockQuerySet([ return MockQuerySet(
item for item in self.items [item for item in self.items if item.created > int(created__gt)]
if item.created > int(created__gt) )
])
assert created__lt is not None assert created__lt is not None
return MockQuerySet([ return MockQuerySet(
item for item in self.items [item for item in self.items if item.created < int(created__lt)]
if item.created < int(created__lt) )
])
def order_by(self, *ordering): def order_by(self, *ordering):
if ordering[0].startswith('-'): if ordering[0].startswith("-"):
return MockQuerySet(list(reversed(self.items))) return MockQuerySet(list(reversed(self.items)))
return self return self
@ -830,21 +853,48 @@ class TestCursorPagination(CursorPaginationTestsMixin):
class ExamplePagination(pagination.CursorPagination): class ExamplePagination(pagination.CursorPagination):
page_size = 5 page_size = 5
page_size_query_param = 'page_size' page_size_query_param = "page_size"
max_page_size = 20 max_page_size = 20
ordering = 'created' ordering = "created"
self.pagination = ExamplePagination() self.pagination = ExamplePagination()
self.queryset = MockQuerySet([ self.queryset = MockQuerySet(
MockObject(idx) for idx in [ [
1, 1, 1, 1, 1, MockObject(idx)
1, 2, 3, 4, 4, for idx in [
4, 4, 5, 6, 7, 1,
7, 7, 7, 7, 7, 1,
7, 7, 7, 8, 9, 1,
9, 9, 9, 9, 9 1,
1,
1,
2,
3,
4,
4,
4,
4,
5,
6,
7,
7,
7,
7,
7,
7,
7,
7,
7,
8,
9,
9,
9,
9,
9,
9,
]
] ]
]) )
def get_pages(self, url): def get_pages(self, url):
""" """
@ -888,18 +938,42 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
def setUp(self): def setUp(self):
class ExamplePagination(pagination.CursorPagination): class ExamplePagination(pagination.CursorPagination):
page_size = 5 page_size = 5
page_size_query_param = 'page_size' page_size_query_param = "page_size"
max_page_size = 20 max_page_size = 20
ordering = 'created' ordering = "created"
self.pagination = ExamplePagination() self.pagination = ExamplePagination()
data = [ data = [
1, 1, 1, 1, 1, 1,
1, 2, 3, 4, 4, 1,
4, 4, 5, 6, 7, 1,
7, 7, 7, 7, 7, 1,
7, 7, 7, 8, 9, 1,
9, 9, 9, 9, 9 1,
2,
3,
4,
4,
4,
4,
5,
6,
7,
7,
7,
7,
7,
7,
7,
7,
7,
8,
9,
9,
9,
9,
9,
9,
] ]
for idx in data: for idx in data:
CursorPaginationModel.objects.create(created=idx) CursorPaginationModel.objects.create(created=idx)
@ -914,7 +988,7 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
""" """
request = Request(factory.get(url)) request = Request(factory.get(url))
queryset = self.pagination.paginate_queryset(self.queryset, request) queryset = self.pagination.paginate_queryset(self.queryset, request)
current = [item['created'] for item in queryset] current = [item["created"] for item in queryset]
next_url = self.pagination.get_next_link() next_url = self.pagination.get_next_link()
previous_url = self.pagination.get_previous_link() previous_url = self.pagination.get_previous_link()
@ -922,14 +996,14 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
if next_url is not None: if next_url is not None:
request = Request(factory.get(next_url)) request = Request(factory.get(next_url))
queryset = self.pagination.paginate_queryset(self.queryset, request) queryset = self.pagination.paginate_queryset(self.queryset, request)
next = [item['created'] for item in queryset] next = [item["created"] for item in queryset]
else: else:
next = None next = None
if previous_url is not None: if previous_url is not None:
request = Request(factory.get(previous_url)) request = Request(factory.get(previous_url))
queryset = self.pagination.paginate_queryset(self.queryset, request) queryset = self.pagination.paginate_queryset(self.queryset, request)
previous = [item['created'] for item in queryset] previous = [item["created"] for item in queryset]
else: else:
previous = None previous = None

View File

@ -7,7 +7,8 @@ import math
import pytest import pytest
from django import forms from django import forms
from django.core.files.uploadhandler import ( from django.core.files.uploadhandler import (
MemoryFileUploadHandler, TemporaryFileUploadHandler MemoryFileUploadHandler,
TemporaryFileUploadHandler,
) )
from django.http.request import RawPostDataException from django.http.request import RawPostDataException
from django.test import TestCase from django.test import TestCase
@ -15,7 +16,10 @@ from django.utils.six import StringIO
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.parsers import ( from rest_framework.parsers import (
FileUploadParser, FormParser, JSONParser, MultiPartParser FileUploadParser,
FormParser,
JSONParser,
MultiPartParser,
) )
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
@ -44,16 +48,15 @@ class TestFileUploadParser(TestCase):
def setUp(self): def setUp(self):
class MockRequest(object): class MockRequest(object):
pass pass
self.stream = io.BytesIO(
"Test text file".encode('utf-8') self.stream = io.BytesIO("Test text file".encode("utf-8"))
)
request = MockRequest() request = MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),) request.upload_handlers = (MemoryFileUploadHandler(),)
request.META = { request.META = {
'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt', "HTTP_CONTENT_DISPOSITION": "Content-Disposition: inline; filename=file.txt",
'HTTP_CONTENT_LENGTH': 14, "HTTP_CONTENT_LENGTH": 14,
} }
self.parser_context = {'request': request, 'kwargs': {}} self.parser_context = {"request": request, "kwargs": {}}
def test_parse(self): def test_parse(self):
""" """
@ -62,7 +65,7 @@ class TestFileUploadParser(TestCase):
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
data_and_files = parser.parse(self.stream, None, self.parser_context) data_and_files = parser.parse(self.stream, None, self.parser_context)
file_obj = data_and_files.files['file'] file_obj = data_and_files.files["file"]
assert file_obj.size == 14 assert file_obj.size == 14
def test_parse_missing_filename(self): def test_parse_missing_filename(self):
@ -71,10 +74,13 @@ class TestFileUploadParser(TestCase):
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' assert (
str(excinfo.value)
== "Missing filename. Request should include a Content-Disposition header with a filename parameter."
)
def test_parse_missing_filename_multiple_upload_handlers(self): def test_parse_missing_filename_multiple_upload_handlers(self):
""" """
@ -83,14 +89,17 @@ class TestFileUploadParser(TestCase):
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context["request"].upload_handlers = (
MemoryFileUploadHandler(),
MemoryFileUploadHandler(), MemoryFileUploadHandler(),
MemoryFileUploadHandler()
) )
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = '' self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' assert (
str(excinfo.value)
== "Missing filename. Request should include a Content-Disposition header with a filename parameter."
)
def test_parse_missing_filename_large_file(self): def test_parse_missing_filename_large_file(self):
""" """
@ -98,54 +107,59 @@ class TestFileUploadParser(TestCase):
""" """
parser = FileUploadParser() parser = FileUploadParser()
self.stream.seek(0) self.stream.seek(0)
self.parser_context['request'].upload_handlers = ( self.parser_context["request"].upload_handlers = (TemporaryFileUploadHandler(),)
TemporaryFileUploadHandler(), self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
with pytest.raises(ParseError) as excinfo: with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context) parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.' assert (
str(excinfo.value)
== "Missing filename. Request should include a Content-Disposition header with a filename parameter."
)
def test_get_filename(self): def test_get_filename(self):
parser = FileUploadParser() parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'file.txt' assert filename == "file.txt"
def test_get_encoded_filename(self): def test_get_encoded_filename(self):
parser = FileUploadParser() parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt') self.__replace_content_disposition("inline; filename*=utf-8''ÀĥƦ.txt")
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == "ÀĥƦ.txt"
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt') self.__replace_content_disposition(
"inline; filename=fallback.txt; filename*=utf-8''ÀĥƦ.txt"
)
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == "ÀĥƦ.txt"
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt') self.__replace_content_disposition(
"inline; filename=fallback.txt; filename*=utf-8'en-us'ÀĥƦ.txt"
)
filename = parser.get_filename(self.stream, None, self.parser_context) filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt' assert filename == "ÀĥƦ.txt"
def __replace_content_disposition(self, disposition): def __replace_content_disposition(self, disposition):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = disposition
class TestJSONParser(TestCase): class TestJSONParser(TestCase):
def bytes(self, value): def bytes(self, value):
return io.BytesIO(value.encode('utf-8')) return io.BytesIO(value.encode("utf-8"))
def test_float_strictness(self): def test_float_strictness(self):
parser = JSONParser() parser = JSONParser()
# Default to strict # Default to strict
for value in ['Infinity', '-Infinity', 'NaN']: for value in ["Infinity", "-Infinity", "NaN"]:
with pytest.raises(ParseError): with pytest.raises(ParseError):
parser.parse(self.bytes(value)) parser.parse(self.bytes(value))
parser.strict = False parser.strict = False
assert parser.parse(self.bytes('Infinity')) == float('inf') assert parser.parse(self.bytes("Infinity")) == float("inf")
assert parser.parse(self.bytes('-Infinity')) == float('-inf') assert parser.parse(self.bytes("-Infinity")) == float("-inf")
assert math.isnan(parser.parse(self.bytes('NaN'))) assert math.isnan(parser.parse(self.bytes("NaN")))
class TestPOSTAccessed(TestCase): class TestPOSTAccessed(TestCase):
@ -153,28 +167,28 @@ class TestPOSTAccessed(TestCase):
self.factory = APIRequestFactory() self.factory = APIRequestFactory()
def test_post_accessed_in_post_method(self): def test_post_accessed_in_post_method(self):
django_request = self.factory.post('/', {'foo': 'bar'}) django_request = self.factory.post("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST django_request.POST
assert request.POST == {'foo': ['bar']} assert request.POST == {"foo": ["bar"]}
assert request.data == {'foo': ['bar']} assert request.data == {"foo": ["bar"]}
def test_post_accessed_in_post_method_with_json_parser(self): def test_post_accessed_in_post_method_with_json_parser(self):
django_request = self.factory.post('/', {'foo': 'bar'}) django_request = self.factory.post("/", {"foo": "bar"})
request = Request(django_request, parsers=[JSONParser()]) request = Request(django_request, parsers=[JSONParser()])
django_request.POST django_request.POST
assert request.POST == {} assert request.POST == {}
assert request.data == {} assert request.data == {}
def test_post_accessed_in_put_method(self): def test_post_accessed_in_put_method(self):
django_request = self.factory.put('/', {'foo': 'bar'}) django_request = self.factory.put("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST django_request.POST
assert request.POST == {'foo': ['bar']} assert request.POST == {"foo": ["bar"]}
assert request.data == {'foo': ['bar']} assert request.data == {"foo": ["bar"]}
def test_request_read_before_parsing(self): def test_request_read_before_parsing(self):
django_request = self.factory.put('/', {'foo': 'bar'}) django_request = self.factory.put("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()]) request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.read() django_request.read()
with pytest.raises(RawPostDataException): with pytest.raises(RawPostDataException):

Some files were not shown because too many files have changed in this diff Show More