diff --git a/requirements/requirements-codestyle.txt b/requirements/requirements-codestyle.txt index 2b7bad436..fcdfb54e8 100644 --- a/requirements/requirements-codestyle.txt +++ b/requirements/requirements-codestyle.txt @@ -4,7 +4,7 @@ flake8-tidy-imports==1.1.0 pycodestyle==2.3.1 # Sort and lint imports -isort==4.3.3 +isort==4.3.17 # black black==19.3b0 \ No newline at end of file diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 55c06982d..cd22cd368 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -7,22 +7,22 @@ ______ _____ _____ _____ __ \_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_| """ -__title__ = 'Django REST framework' -__version__ = '3.9.2' -__author__ = 'Tom Christie' -__license__ = 'BSD 2-Clause' -__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd' +__title__ = "Django REST framework" +__version__ = "3.9.2" +__author__ = "Tom Christie" +__license__ = "BSD 2-Clause" +__copyright__ = "Copyright 2011-2019 Encode OSS Ltd" # Version synonym VERSION = __version__ # Header encoding (see RFC5987) -HTTP_HEADER_ENCODING = 'iso-8859-1' +HTTP_HEADER_ENCODING = "iso-8859-1" # Default datetime input and output formats -ISO_8601 = 'iso-8601' +ISO_8601 = "iso-8601" -default_app_config = 'rest_framework.apps.RestFrameworkConfig' +default_app_config = "rest_framework.apps.RestFrameworkConfig" class RemovedInDRF310Warning(DeprecationWarning): diff --git a/rest_framework/apps.py b/rest_framework/apps.py index f6013eb7e..af2a09038 100644 --- a/rest_framework/apps.py +++ b/rest_framework/apps.py @@ -2,7 +2,7 @@ from django.apps import AppConfig class RestFrameworkConfig(AppConfig): - name = 'rest_framework' + name = "rest_framework" verbose_name = "Django REST framework" def ready(self): diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 25150d525..19b74f115 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -20,7 +20,7 @@ def get_authorization_header(request): Hide some test client ickyness where the header can be unicode. """ - auth = request.META.get('HTTP_AUTHORIZATION', b'') + auth = request.META.get("HTTP_AUTHORIZATION", b"") if isinstance(auth, text_type): # Work around django test client oddness auth = auth.encode(HTTP_HEADER_ENCODING) @@ -57,7 +57,8 @@ class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ - www_authenticate_realm = 'api' + + www_authenticate_realm = "api" def authenticate(self, request): """ @@ -66,20 +67,24 @@ class BasicAuthentication(BaseAuthentication): """ auth = get_authorization_header(request).split() - if not auth or auth[0].lower() != b'basic': + if not auth or auth[0].lower() != b"basic": return None if len(auth) == 1: - msg = _('Invalid basic header. No credentials provided.') + msg = _("Invalid basic header. No credentials provided.") raise exceptions.AuthenticationFailed(msg) elif len(auth) > 2: - msg = _('Invalid basic header. Credentials string should not contain spaces.') + msg = _( + "Invalid basic header. Credentials string should not contain spaces." + ) raise exceptions.AuthenticationFailed(msg) try: - auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':') + auth_parts = ( + base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(":") + ) except (TypeError, UnicodeDecodeError, binascii.Error): - msg = _('Invalid basic header. Credentials not correctly base64 encoded.') + msg = _("Invalid basic header. Credentials not correctly base64 encoded.") raise exceptions.AuthenticationFailed(msg) userid, password = auth_parts[0], auth_parts[2] @@ -90,17 +95,14 @@ class BasicAuthentication(BaseAuthentication): Authenticate the userid and password against username and password with optional request for context. """ - credentials = { - get_user_model().USERNAME_FIELD: userid, - 'password': password - } + credentials = {get_user_model().USERNAME_FIELD: userid, "password": password} user = authenticate(request=request, **credentials) if user is None: - raise exceptions.AuthenticationFailed(_('Invalid username/password.')) + raise exceptions.AuthenticationFailed(_("Invalid username/password.")) if not user.is_active: - raise exceptions.AuthenticationFailed(_('User inactive or deleted.')) + raise exceptions.AuthenticationFailed(_("User inactive or deleted.")) return (user, None) @@ -120,7 +122,7 @@ class SessionAuthentication(BaseAuthentication): """ # Get the session-based user from the underlying HttpRequest object - user = getattr(request._request, 'user', None) + user = getattr(request._request, "user", None) # Unauthenticated, CSRF validation not required if not user or not user.is_active: @@ -141,7 +143,7 @@ class SessionAuthentication(BaseAuthentication): reason = check.process_view(request, None, (), {}) if reason: # CSRF failed, bail with explicit error message - raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) + raise exceptions.PermissionDenied("CSRF Failed: %s" % reason) class TokenAuthentication(BaseAuthentication): @@ -154,13 +156,14 @@ class TokenAuthentication(BaseAuthentication): Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a """ - keyword = 'Token' + keyword = "Token" model = None def get_model(self): if self.model is not None: return self.model from rest_framework.authtoken.models import Token + return Token """ @@ -177,16 +180,18 @@ class TokenAuthentication(BaseAuthentication): return None if len(auth) == 1: - msg = _('Invalid token header. No credentials provided.') + msg = _("Invalid token header. No credentials provided.") raise exceptions.AuthenticationFailed(msg) elif len(auth) > 2: - msg = _('Invalid token header. Token string should not contain spaces.') + msg = _("Invalid token header. Token string should not contain spaces.") raise exceptions.AuthenticationFailed(msg) try: token = auth[1].decode() except UnicodeError: - msg = _('Invalid token header. Token string should not contain invalid characters.') + msg = _( + "Invalid token header. Token string should not contain invalid characters." + ) raise exceptions.AuthenticationFailed(msg) return self.authenticate_credentials(token) @@ -194,12 +199,12 @@ class TokenAuthentication(BaseAuthentication): def authenticate_credentials(self, key): model = self.get_model() try: - token = model.objects.select_related('user').get(key=key) + token = model.objects.select_related("user").get(key=key) except model.DoesNotExist: - raise exceptions.AuthenticationFailed(_('Invalid token.')) + raise exceptions.AuthenticationFailed(_("Invalid token.")) if not token.user.is_active: - raise exceptions.AuthenticationFailed(_('User inactive or deleted.')) + raise exceptions.AuthenticationFailed(_("User inactive or deleted.")) return (token.user, token) diff --git a/rest_framework/authtoken/__init__.py b/rest_framework/authtoken/__init__.py index 82f5b9171..bc19a2e04 100644 --- a/rest_framework/authtoken/__init__.py +++ b/rest_framework/authtoken/__init__.py @@ -1 +1 @@ -default_app_config = 'rest_framework.authtoken.apps.AuthTokenConfig' +default_app_config = "rest_framework.authtoken.apps.AuthTokenConfig" diff --git a/rest_framework/authtoken/admin.py b/rest_framework/authtoken/admin.py index 1a507249b..f2ca70ec2 100644 --- a/rest_framework/authtoken/admin.py +++ b/rest_framework/authtoken/admin.py @@ -4,9 +4,9 @@ from rest_framework.authtoken.models import Token class TokenAdmin(admin.ModelAdmin): - list_display = ('key', 'user', 'created') - fields = ('user',) - ordering = ('-created',) + list_display = ("key", "user", "created") + fields = ("user",) + ordering = ("-created",) admin.site.register(Token, TokenAdmin) diff --git a/rest_framework/authtoken/apps.py b/rest_framework/authtoken/apps.py index ad01cb404..7b2aac0c6 100644 --- a/rest_framework/authtoken/apps.py +++ b/rest_framework/authtoken/apps.py @@ -3,5 +3,5 @@ from django.utils.translation import ugettext_lazy as _ class AuthTokenConfig(AppConfig): - name = 'rest_framework.authtoken' + name = "rest_framework.authtoken" verbose_name = _("Auth Token") diff --git a/rest_framework/authtoken/management/commands/drf_create_token.py b/rest_framework/authtoken/management/commands/drf_create_token.py index 8e06812db..5dc41a97d 100644 --- a/rest_framework/authtoken/management/commands/drf_create_token.py +++ b/rest_framework/authtoken/management/commands/drf_create_token.py @@ -3,11 +3,12 @@ from django.core.management.base import BaseCommand, CommandError from rest_framework.authtoken.models import Token + UserModel = get_user_model() class Command(BaseCommand): - help = 'Create DRF Token for a given user' + help = "Create DRF Token for a given user" def create_user_token(self, username, reset_token): user = UserModel._default_manager.get_by_natural_key(username) @@ -19,27 +20,27 @@ class Command(BaseCommand): return token[0] def add_arguments(self, parser): - parser.add_argument('username', type=str) + parser.add_argument("username", type=str) parser.add_argument( - '-r', - '--reset', - action='store_true', - dest='reset_token', + "-r", + "--reset", + action="store_true", + dest="reset_token", default=False, - help='Reset existing User token and create a new one', + help="Reset existing User token and create a new one", ) def handle(self, *args, **options): - username = options['username'] - reset_token = options['reset_token'] + username = options["username"] + reset_token = options["reset_token"] try: token = self.create_user_token(username, reset_token) except UserModel.DoesNotExist: raise CommandError( - 'Cannot create the Token: user {0} does not exist'.format( - username) + "Cannot create the Token: user {0} does not exist".format(username) ) self.stdout.write( - 'Generated token {0} for user {1}'.format(token.key, username)) + "Generated token {0} for user {1}".format(token.key, username) + ) diff --git a/rest_framework/authtoken/migrations/0001_initial.py b/rest_framework/authtoken/migrations/0001_initial.py index 75780fedf..708bade1a 100644 --- a/rest_framework/authtoken/migrations/0001_initial.py +++ b/rest_framework/authtoken/migrations/0001_initial.py @@ -7,20 +7,27 @@ from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] + dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)] operations = [ migrations.CreateModel( - name='Token', + name="Token", fields=[ - ('key', models.CharField(primary_key=True, serialize=False, max_length=40)), - ('created', models.DateTimeField(auto_now_add=True)), - ('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token', on_delete=models.CASCADE)), + ( + "key", + models.CharField(primary_key=True, serialize=False, max_length=40), + ), + ("created", models.DateTimeField(auto_now_add=True)), + ( + "user", + models.OneToOneField( + to=settings.AUTH_USER_MODEL, + related_name="auth_token", + on_delete=models.CASCADE, + ), + ), ], - options={ - }, + options={}, bases=(models.Model,), - ), + ) ] diff --git a/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py b/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py index 9f7e58e22..ac404c764 100644 --- a/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py +++ b/rest_framework/authtoken/migrations/0002_auto_20160226_1747.py @@ -7,28 +7,33 @@ from django.db import migrations, models class Migration(migrations.Migration): - dependencies = [ - ('authtoken', '0001_initial'), - ] + dependencies = [("authtoken", "0001_initial")] operations = [ migrations.AlterModelOptions( - name='token', - options={'verbose_name_plural': 'Tokens', 'verbose_name': 'Token'}, + name="token", + options={"verbose_name_plural": "Tokens", "verbose_name": "Token"}, ), migrations.AlterField( - model_name='token', - name='created', - field=models.DateTimeField(verbose_name='Created', auto_now_add=True), + model_name="token", + name="created", + field=models.DateTimeField(verbose_name="Created", auto_now_add=True), ), migrations.AlterField( - model_name='token', - name='key', - field=models.CharField(verbose_name='Key', max_length=40, primary_key=True, serialize=False), + model_name="token", + name="key", + field=models.CharField( + verbose_name="Key", max_length=40, primary_key=True, serialize=False + ), ), migrations.AlterField( - model_name='token', - name='user', - field=models.OneToOneField(to=settings.AUTH_USER_MODEL, verbose_name='User', related_name='auth_token', on_delete=models.CASCADE), + model_name="token", + name="user", + field=models.OneToOneField( + to=settings.AUTH_USER_MODEL, + verbose_name="User", + related_name="auth_token", + on_delete=models.CASCADE, + ), ), ] diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 7e96eff93..d99a8a524 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -12,10 +12,13 @@ class Token(models.Model): """ The default authorization token model. """ + key = models.CharField(_("Key"), max_length=40, primary_key=True) user = models.OneToOneField( - settings.AUTH_USER_MODEL, related_name='auth_token', - on_delete=models.CASCADE, verbose_name=_("User") + settings.AUTH_USER_MODEL, + related_name="auth_token", + on_delete=models.CASCADE, + verbose_name=_("User"), ) created = models.DateTimeField(_("Created"), auto_now_add=True) @@ -25,7 +28,7 @@ class Token(models.Model): # # Also see corresponding ticket: # https://github.com/encode/django-rest-framework/issues/705 - abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS + abstract = "rest_framework.authtoken" not in settings.INSTALLED_APPS verbose_name = _("Token") verbose_name_plural = _("Tokens") diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index e5f46dd66..0587fe0a5 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -7,28 +7,29 @@ from rest_framework import serializers class AuthTokenSerializer(serializers.Serializer): username = serializers.CharField(label=_("Username")) password = serializers.CharField( - label=_("Password"), - style={'input_type': 'password'}, - trim_whitespace=False + label=_("Password"), style={"input_type": "password"}, trim_whitespace=False ) def validate(self, attrs): - username = attrs.get('username') - password = attrs.get('password') + username = attrs.get("username") + password = attrs.get("password") if username and password: - user = authenticate(request=self.context.get('request'), - username=username, password=password) + user = authenticate( + request=self.context.get("request"), + username=username, + password=password, + ) # The authenticate call simply returns None for is_active=False # users. (Assuming the default ModelBackend authentication # backend.) if not user: - msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg, code='authorization') + msg = _("Unable to log in with provided credentials.") + raise serializers.ValidationError(msg, code="authorization") else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg, code='authorization') + raise serializers.ValidationError(msg, code="authorization") - attrs['user'] = user + attrs["user"] = user return attrs diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index a8c751d51..f73cbb295 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -10,7 +10,7 @@ from rest_framework.views import APIView class ObtainAuthToken(APIView): throttle_classes = () permission_classes = () - parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) + parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser) renderer_classes = (renderers.JSONRenderer,) serializer_class = AuthTokenSerializer if coreapi is not None and coreschema is not None: @@ -19,7 +19,7 @@ class ObtainAuthToken(APIView): coreapi.Field( name="username", required=True, - location='form', + location="form", schema=coreschema.String( title="Username", description="Valid username for authentication", @@ -28,7 +28,7 @@ class ObtainAuthToken(APIView): coreapi.Field( name="password", required=True, - location='form', + location="form", schema=coreschema.String( title="Password", description="Valid password for authentication", @@ -39,12 +39,13 @@ class ObtainAuthToken(APIView): ) def post(self, request, *args, **kwargs): - serializer = self.serializer_class(data=request.data, - context={'request': request}) + serializer = self.serializer_class( + data=request.data, context={"request": request} + ) serializer.is_valid(raise_exception=True) - user = serializer.validated_data['user'] + user = serializer.validated_data["user"] token, created = Token.objects.get_or_create(user=user) - return Response({'token': token.key}) + return Response({"token": token.key}) obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/checks.py b/rest_framework/checks.py index c1e626018..fe17f0046 100644 --- a/rest_framework/checks.py +++ b/rest_framework/checks.py @@ -6,16 +6,17 @@ def pagination_system_check(app_configs, **kwargs): errors = [] # Use of default page size setting requires a default Paginator class from rest_framework.settings import api_settings + if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS: errors.append( Warning( "You have specified a default PAGE_SIZE pagination rest_framework setting," "without specifying also a DEFAULT_PAGINATION_CLASS.", hint="The default for DEFAULT_PAGINATION_CLASS is None. " - "In previous versions this was PageNumberPagination. " - "If you wish to define PAGE_SIZE globally whilst defining " - "pagination_class on a per-view basis you may silence this check.", - id="rest_framework.W001" + "In previous versions this was PageNumberPagination. " + "If you wish to define PAGE_SIZE globally whilst defining " + "pagination_class on a per-view basis you may silence this check.", + id="rest_framework.W001", ) ) return errors diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 9422e6ad5..9026c1357 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -12,18 +12,16 @@ from django.core import validators from django.utils import six from django.views.generic import View -try: - # Python 3 - from collections.abc import Mapping, MutableMapping # noqa -except ImportError: - # Python 2.7 - from collections import Mapping, MutableMapping # noqa try: - from django.urls import ( # noqa - URLPattern, - URLResolver, - ) + # Python 3 + from collections.abc import Mapping, MutableMapping # noqa +except ImportError: + # Python 2.7 + from collections import Mapping, MutableMapping # noqa + +try: + from django.urls import URLPattern, URLResolver # noqa except ImportError: # Will be removed in Django 2.0 from django.urls import ( # noqa @@ -47,7 +45,7 @@ def get_original_route(urlpattern): Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path(). """ - if hasattr(urlpattern, 'pattern'): + if hasattr(urlpattern, "pattern"): # Django 2.0 return str(urlpattern.pattern) else: @@ -60,7 +58,7 @@ def get_regex_pattern(urlpattern): Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression, unlike get_original_route above. """ - if hasattr(urlpattern, 'pattern'): + if hasattr(urlpattern, "pattern"): # Django 2.0 return urlpattern.pattern.regex.pattern else: @@ -69,9 +67,10 @@ def get_regex_pattern(urlpattern): def is_route_pattern(urlpattern): - if hasattr(urlpattern, 'pattern'): + if hasattr(urlpattern, "pattern"): # Django 2.0 from django.urls.resolvers import RoutePattern + return isinstance(urlpattern.pattern, RoutePattern) else: # Django < 2.0 @@ -82,6 +81,7 @@ def make_url_resolver(regex, urlpatterns): try: # Django 2.0 from django.urls.resolvers import RegexPattern + return URLResolver(RegexPattern(regex), urlpatterns) except ImportError: @@ -93,7 +93,7 @@ def unicode_repr(instance): # Get the repr of an instance, but ensure it is a unicode string # on both python 3 (already the case) and 2 (not the case). if six.PY2: - return repr(instance).decode('utf-8') + return repr(instance).decode("utf-8") return repr(instance) @@ -102,21 +102,21 @@ def unicode_to_repr(value): # the Python version. We wrap all our `__repr__` implementations with # this and then use unicode throughout internally. if six.PY2: - return value.encode('utf-8') + return value.encode("utf-8") return value def unicode_http_header(value): # Coerce HTTP header value to unicode. if isinstance(value, bytes): - return value.decode('iso-8859-1') + return value.decode("iso-8859-1") return value def distinct(queryset, base): if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle": # distinct analogue for Oracle users - return base.filter(pk__in=set(queryset.values_list('pk', flat=True))) + return base.filter(pk__in=set(queryset.values_list("pk", flat=True))) return queryset.distinct() @@ -172,27 +172,27 @@ def is_guardian_installed(): # Guardian 1.5.0, for Django 2.2 is NOT compatible with Python 2.7. # Remove when dropping PY2. return False - return 'guardian' in settings.INSTALLED_APPS + return "guardian" in settings.INSTALLED_APPS # PATCH method is not implemented by Django -if 'patch' not in View.http_method_names: - View.http_method_names = View.http_method_names + ['patch'] +if "patch" not in View.http_method_names: + View.http_method_names = View.http_method_names + ["patch"] # Markdown is optional try: import markdown - if markdown.version <= '2.2': - HEADERID_EXT_PATH = 'headerid' - LEVEL_PARAM = 'level' - elif markdown.version < '2.6': - HEADERID_EXT_PATH = 'markdown.extensions.headerid' - LEVEL_PARAM = 'level' + if markdown.version <= "2.2": + HEADERID_EXT_PATH = "headerid" + LEVEL_PARAM = "level" + elif markdown.version < "2.6": + HEADERID_EXT_PATH = "markdown.extensions.headerid" + LEVEL_PARAM = "level" else: - HEADERID_EXT_PATH = 'markdown.extensions.toc' - LEVEL_PARAM = 'baselevel' + HEADERID_EXT_PATH = "markdown.extensions.toc" + LEVEL_PARAM = "baselevel" def apply_markdown(text): """ @@ -200,16 +200,14 @@ try: of '#' style headers to

. """ extensions = [HEADERID_EXT_PATH] - extension_configs = { - HEADERID_EXT_PATH: { - LEVEL_PARAM: '2' - } - } + extension_configs = {HEADERID_EXT_PATH: {LEVEL_PARAM: "2"}} md = markdown.Markdown( extensions=extensions, extension_configs=extension_configs ) md_filter_add_syntax_highlight(md) return md.convert(text) + + except ImportError: apply_markdown = None markdown = None @@ -227,7 +225,8 @@ try: def pygments_css(style): formatter = HtmlFormatter(style=style) - return formatter.get_style_defs('.highlight') + return formatter.get_style_defs(".highlight") + except ImportError: pygments = None @@ -238,6 +237,7 @@ except ImportError: def pygments_css(style): return None + if markdown is not None and pygments is not None: # starting from this blogpost and modified to support current markdown extensions API # https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/ @@ -246,8 +246,7 @@ if markdown is not None and pygments is not None: import re class CodeBlockPreprocessor(Preprocessor): - pattern = re.compile( - r'^\s*``` *([^\n]+)\n(.+?)^\s*```', re.M | re.S) + pattern = re.compile(r"^\s*``` *([^\n]+)\n(.+?)^\s*```", re.M | re.S) formatter = HtmlFormatter() @@ -257,17 +256,25 @@ if markdown is not None and pygments is not None: lexer = get_lexer_by_name(m.group(1)) except (ValueError, NameError): lexer = TextLexer() - code = m.group(2).replace('\t', ' ') + code = m.group(2).replace("\t", " ") code = pygments.highlight(code, lexer, self.formatter) - code = code.replace('\n\n', '\n \n').replace('\n', '
').replace('\\@', '@') - return '\n\n%s\n\n' % code + code = ( + code.replace("\n\n", "\n \n") + .replace("\n", "
") + .replace("\\@", "@") + ) + return "\n\n%s\n\n" % code + ret = self.pattern.sub(repl, "\n".join(lines)) return ret.split("\n") def md_filter_add_syntax_highlight(md): - md.preprocessors.add('highlight', CodeBlockPreprocessor(), "_begin") + md.preprocessors.add("highlight", CodeBlockPreprocessor(), "_begin") return True + + else: + def md_filter_add_syntax_highlight(md): return False @@ -276,7 +283,8 @@ else: try: from django.urls import include, path, re_path, register_converter # noqa except ImportError: - from django.conf.urls import include, url # noqa + from django.conf.urls import include, url # noqa + path = None register_converter = None re_path = url @@ -285,13 +293,13 @@ except ImportError: # `separators` argument to `json.dumps()` differs between 2.x and 3.x # See: https://bugs.python.org/issue22767 if six.PY3: - SHORT_SEPARATORS = (',', ':') - LONG_SEPARATORS = (', ', ': ') - INDENT_SEPARATORS = (',', ': ') + SHORT_SEPARATORS = (",", ":") + LONG_SEPARATORS = (", ", ": ") + INDENT_SEPARATORS = (",", ": ") else: - SHORT_SEPARATORS = (b',', b':') - LONG_SEPARATORS = (b', ', b': ') - INDENT_SEPARATORS = (b',', b': ') + SHORT_SEPARATORS = (b",", b":") + LONG_SEPARATORS = (b", ", b": ") + INDENT_SEPARATORS = (b",", b": ") class CustomValidatorMessage(object): @@ -303,7 +311,7 @@ class CustomValidatorMessage(object): """ def __init__(self, *args, **kwargs): - self.message = kwargs.pop('message', self.message) + self.message = kwargs.pop("message", self.message) super(CustomValidatorMessage, self).__init__(*args, **kwargs) diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 30bfcc4e5..6999e6a79 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -23,14 +23,14 @@ def api_view(http_method_names=None): Decorator that converts a function-based view into an APIView subclass. Takes a list of allowed methods for the view as an argument. """ - http_method_names = ['GET'] if (http_method_names is None) else http_method_names + http_method_names = ["GET"] if (http_method_names is None) else http_method_names def decorator(func): WrappedAPIView = type( - six.PY3 and 'WrappedAPIView' or b'WrappedAPIView', + six.PY3 and "WrappedAPIView" or b"WrappedAPIView", (APIView,), - {'__doc__': func.__doc__} + {"__doc__": func.__doc__}, ) # Note, the above allows us to set the docstring. @@ -41,15 +41,20 @@ def api_view(http_method_names=None): # WrappedAPIView.__doc__ = func.doc <--- Not possible to do this # api_view applied without (method_names) - assert not(isinstance(http_method_names, types.FunctionType)), \ - '@api_view missing list of allowed HTTP methods' + assert not ( + isinstance(http_method_names, types.FunctionType) + ), "@api_view missing list of allowed HTTP methods" # api_view applied with eg. string instead of list of strings - assert isinstance(http_method_names, (list, tuple)), \ - '@api_view expected a list of strings, received %s' % type(http_method_names).__name__ + assert isinstance(http_method_names, (list, tuple)), ( + "@api_view expected a list of strings, received %s" + % type(http_method_names).__name__ + ) - allowed_methods = set(http_method_names) | {'options'} - WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] + allowed_methods = set(http_method_names) | {"options"} + WrappedAPIView.http_method_names = [ + method.lower() for method in allowed_methods + ] def handler(self, *args, **kwargs): return func(*args, **kwargs) @@ -60,23 +65,27 @@ def api_view(http_method_names=None): WrappedAPIView.__name__ = func.__name__ WrappedAPIView.__module__ = func.__module__ - WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes', - APIView.renderer_classes) + WrappedAPIView.renderer_classes = getattr( + func, "renderer_classes", APIView.renderer_classes + ) - WrappedAPIView.parser_classes = getattr(func, 'parser_classes', - APIView.parser_classes) + WrappedAPIView.parser_classes = getattr( + func, "parser_classes", APIView.parser_classes + ) - WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', - APIView.authentication_classes) + WrappedAPIView.authentication_classes = getattr( + func, "authentication_classes", APIView.authentication_classes + ) - WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes', - APIView.throttle_classes) + WrappedAPIView.throttle_classes = getattr( + func, "throttle_classes", APIView.throttle_classes + ) - WrappedAPIView.permission_classes = getattr(func, 'permission_classes', - APIView.permission_classes) + WrappedAPIView.permission_classes = getattr( + func, "permission_classes", APIView.permission_classes + ) - WrappedAPIView.schema = getattr(func, 'schema', - APIView.schema) + WrappedAPIView.schema = getattr(func, "schema", APIView.schema) return WrappedAPIView.as_view() @@ -87,6 +96,7 @@ def renderer_classes(renderer_classes): def decorator(func): func.renderer_classes = renderer_classes return func + return decorator @@ -94,6 +104,7 @@ def parser_classes(parser_classes): def decorator(func): func.parser_classes = parser_classes return func + return decorator @@ -101,6 +112,7 @@ def authentication_classes(authentication_classes): def decorator(func): func.authentication_classes = authentication_classes return func + return decorator @@ -108,6 +120,7 @@ def throttle_classes(throttle_classes): def decorator(func): func.throttle_classes = throttle_classes return func + return decorator @@ -115,6 +128,7 @@ def permission_classes(permission_classes): def decorator(func): func.permission_classes = permission_classes return func + return decorator @@ -122,6 +136,7 @@ def schema(view_inspector): def decorator(func): func.schema = view_inspector return func + return decorator @@ -132,15 +147,13 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs): Set the `detail` boolean to determine if this action should apply to instance/detail requests or collection/list requests. """ - methods = ['get'] if (methods is None) else methods + methods = ["get"] if (methods is None) else methods methods = [method.lower() for method in methods] - assert detail is not None, ( - "@action() missing required argument: 'detail'" - ) + assert detail is not None, "@action() missing required argument: 'detail'" # name and suffix are mutually exclusive - if 'name' in kwargs and 'suffix' in kwargs: + if "name" in kwargs and "suffix" in kwargs: raise TypeError("`name` and `suffix` are mutually exclusive arguments.") def decorator(func): @@ -148,15 +161,16 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs): func.detail = detail func.url_path = url_path if url_path else func.__name__ - func.url_name = url_name if url_name else func.__name__.replace('_', '-') + func.url_name = url_name if url_name else func.__name__.replace("_", "-") func.kwargs = kwargs # Set descriptive arguments for viewsets - if 'name' not in kwargs and 'suffix' not in kwargs: - func.kwargs['name'] = pretty_name(func.__name__) - func.kwargs['description'] = func.__doc__ or None + if "name" not in kwargs and "suffix" not in kwargs: + func.kwargs["name"] = pretty_name(func.__name__) + func.kwargs["description"] = func.__doc__ or None return func + return decorator @@ -184,39 +198,42 @@ class MethodMapper(dict): self[method] = self.action.__name__ def _map(self, method, func): - assert method not in self, ( - "Method '%s' has already been mapped to '.%s'." % (method, self[method])) + assert method not in self, "Method '%s' has already been mapped to '.%s'." % ( + method, + self[method], + ) assert func.__name__ != self.action.__name__, ( "Method mapping does not behave like the property decorator. You " - "cannot use the same method name for each mapping declaration.") + "cannot use the same method name for each mapping declaration." + ) self[method] = func.__name__ return func def get(self, func): - return self._map('get', func) + return self._map("get", func) def post(self, func): - return self._map('post', func) + return self._map("post", func) def put(self, func): - return self._map('put', func) + return self._map("put", func) def patch(self, func): - return self._map('patch', func) + return self._map("patch", func) def delete(self, func): - return self._map('delete', func) + return self._map("delete", func) def head(self, func): - return self._map('head', func) + return self._map("head", func) def options(self, func): - return self._map('options', func) + return self._map("options", func) def trace(self, func): - return self._map('trace', func) + return self._map("trace", func) def detail_route(methods=None, **kwargs): @@ -226,14 +243,16 @@ def detail_route(methods=None, **kwargs): warnings.warn( "`detail_route` is deprecated and will be removed in 3.10 in favor of " "`action`, which accepts a `detail` bool. Use `@action(detail=True)` instead.", - RemovedInDRF310Warning, stacklevel=2 + RemovedInDRF310Warning, + stacklevel=2, ) def decorator(func): func = action(methods, detail=True, **kwargs)(func) - if 'url_name' not in kwargs: - func.url_name = func.url_path.replace('_', '-') + if "url_name" not in kwargs: + func.url_name = func.url_path.replace("_", "-") return func + return decorator @@ -244,12 +263,14 @@ def list_route(methods=None, **kwargs): warnings.warn( "`list_route` is deprecated and will be removed in 3.10 in favor of " "`action`, which accepts a `detail` bool. Use `@action(detail=False)` instead.", - RemovedInDRF310Warning, stacklevel=2 + RemovedInDRF310Warning, + stacklevel=2, ) def decorator(func): func = action(methods, detail=False, **kwargs)(func) - if 'url_name' not in kwargs: - func.url_name = func.url_path.replace('_', '-') + if "url_name" not in kwargs: + func.url_name = func.url_path.replace("_", "-") return func + return decorator diff --git a/rest_framework/documentation.py b/rest_framework/documentation.py index 3a78bb341..f86d91c93 100644 --- a/rest_framework/documentation.py +++ b/rest_framework/documentation.py @@ -1,18 +1,25 @@ from django.conf.urls import include, url from rest_framework.renderers import ( - CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer + CoreJSONRenderer, + DocumentationRenderer, + SchemaJSRenderer, ) from rest_framework.schemas import SchemaGenerator, get_schema_view from rest_framework.settings import api_settings def get_docs_view( - title=None, description=None, schema_url=None, public=True, - patterns=None, generator_class=SchemaGenerator, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, - renderer_classes=None): + title=None, + description=None, + schema_url=None, + public=True, + patterns=None, + generator_class=SchemaGenerator, + authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, + permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, + renderer_classes=None, +): if renderer_classes is None: renderer_classes = [DocumentationRenderer, CoreJSONRenderer] @@ -31,10 +38,15 @@ def get_docs_view( def get_schemajs_view( - title=None, description=None, schema_url=None, public=True, - patterns=None, generator_class=SchemaGenerator, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): + title=None, + description=None, + schema_url=None, + public=True, + patterns=None, + generator_class=SchemaGenerator, + authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, + permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, +): renderer_classes = [SchemaJSRenderer] return get_schema_view( @@ -51,11 +63,16 @@ def get_schemajs_view( def include_docs_urls( - title=None, description=None, schema_url=None, public=True, - patterns=None, generator_class=SchemaGenerator, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, - renderer_classes=None): + title=None, + description=None, + schema_url=None, + public=True, + patterns=None, + generator_class=SchemaGenerator, + authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, + permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, + renderer_classes=None, +): docs_view = get_docs_view( title=title, description=description, @@ -78,7 +95,7 @@ def include_docs_urls( permission_classes=permission_classes, ) urls = [ - url(r'^$', docs_view, name='docs-index'), - url(r'^schema.js$', schema_js_view, name='schema-js') + url(r"^$", docs_view, name="docs-index"), + url(r"^schema.js$", schema_js_view, name="schema-js"), ] - return include((urls, 'api-docs'), namespace='api-docs') + return include((urls, "api-docs"), namespace="api-docs") diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index f79b16129..1dc3feca3 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -11,8 +11,7 @@ import math from django.http import JsonResponse from django.utils import six from django.utils.encoding import force_text -from django.utils.translation import ugettext_lazy as _ -from django.utils.translation import ungettext +from django.utils.translation import ugettext_lazy as _, ungettext from rest_framework import status from rest_framework.compat import unicode_to_repr @@ -25,23 +24,20 @@ def _get_error_details(data, default_code=None): lazy translation strings or strings into `ErrorDetail`. """ if isinstance(data, list): - ret = [ - _get_error_details(item, default_code) for item in data - ] + ret = [_get_error_details(item, default_code) for item in data] if isinstance(data, ReturnList): return ReturnList(ret, serializer=data.serializer) return ret elif isinstance(data, dict): ret = { - key: _get_error_details(value, default_code) - for key, value in data.items() + key: _get_error_details(value, default_code) for key, value in data.items() } if isinstance(data, ReturnDict): return ReturnDict(ret, serializer=data.serializer) return ret text = force_text(data) - code = getattr(data, 'code', default_code) + code = getattr(data, "code", default_code) return ErrorDetail(text, code) @@ -58,16 +54,14 @@ def _get_full_details(detail): return [_get_full_details(item) for item in detail] elif isinstance(detail, dict): return {key: _get_full_details(value) for key, value in detail.items()} - return { - 'message': detail, - 'code': detail.code - } + return {"message": detail, "code": detail.code} class ErrorDetail(six.text_type): """ A string-like object that can additionally have a code. """ + code = None def __new__(cls, string, code=None): @@ -86,10 +80,9 @@ class ErrorDetail(six.text_type): return not self.__eq__(other) def __repr__(self): - return unicode_to_repr('ErrorDetail(string=%r, code=%r)' % ( - six.text_type(self), - self.code, - )) + return unicode_to_repr( + "ErrorDetail(string=%r, code=%r)" % (six.text_type(self), self.code) + ) def __hash__(self): return hash(str(self)) @@ -100,9 +93,10 @@ class APIException(Exception): Base class for REST framework exceptions. Subclasses should provide `.status_code` and `.default_detail` properties. """ + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - default_detail = _('A server error occurred.') - default_code = 'error' + default_detail = _("A server error occurred.") + default_code = "error" def __init__(self, detail=None, code=None): if detail is None: @@ -139,10 +133,11 @@ class APIException(Exception): # from rest_framework import serializers # raise serializers.ValidationError('Value was invalid') + class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST - default_detail = _('Invalid input.') - default_code = 'invalid' + default_detail = _("Invalid input.") + default_code = "invalid" def __init__(self, detail=None, code=None): if detail is None: @@ -160,38 +155,38 @@ class ValidationError(APIException): class ParseError(APIException): status_code = status.HTTP_400_BAD_REQUEST - default_detail = _('Malformed request.') - default_code = 'parse_error' + default_detail = _("Malformed request.") + default_code = "parse_error" class AuthenticationFailed(APIException): status_code = status.HTTP_401_UNAUTHORIZED - default_detail = _('Incorrect authentication credentials.') - default_code = 'authentication_failed' + default_detail = _("Incorrect authentication credentials.") + default_code = "authentication_failed" class NotAuthenticated(APIException): status_code = status.HTTP_401_UNAUTHORIZED - default_detail = _('Authentication credentials were not provided.') - default_code = 'not_authenticated' + default_detail = _("Authentication credentials were not provided.") + default_code = "not_authenticated" class PermissionDenied(APIException): status_code = status.HTTP_403_FORBIDDEN - default_detail = _('You do not have permission to perform this action.') - default_code = 'permission_denied' + default_detail = _("You do not have permission to perform this action.") + default_code = "permission_denied" class NotFound(APIException): status_code = status.HTTP_404_NOT_FOUND - default_detail = _('Not found.') - default_code = 'not_found' + default_detail = _("Not found.") + default_code = "not_found" class MethodNotAllowed(APIException): status_code = status.HTTP_405_METHOD_NOT_ALLOWED default_detail = _('Method "{method}" not allowed.') - default_code = 'method_not_allowed' + default_code = "method_not_allowed" def __init__(self, method, detail=None, code=None): if detail is None: @@ -201,8 +196,8 @@ class MethodNotAllowed(APIException): class NotAcceptable(APIException): status_code = status.HTTP_406_NOT_ACCEPTABLE - default_detail = _('Could not satisfy the request Accept header.') - default_code = 'not_acceptable' + default_detail = _("Could not satisfy the request Accept header.") + default_code = "not_acceptable" def __init__(self, detail=None, code=None, available_renderers=None): self.available_renderers = available_renderers @@ -212,7 +207,7 @@ class NotAcceptable(APIException): class UnsupportedMediaType(APIException): status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE default_detail = _('Unsupported media type "{media_type}" in request.') - default_code = 'unsupported_media_type' + default_code = "unsupported_media_type" def __init__(self, media_type, detail=None, code=None): if detail is None: @@ -222,21 +217,28 @@ class UnsupportedMediaType(APIException): class Throttled(APIException): status_code = status.HTTP_429_TOO_MANY_REQUESTS - default_detail = _('Request was throttled.') - extra_detail_singular = 'Expected available in {wait} second.' - extra_detail_plural = 'Expected available in {wait} seconds.' - default_code = 'throttled' + default_detail = _("Request was throttled.") + extra_detail_singular = "Expected available in {wait} second." + extra_detail_plural = "Expected available in {wait} seconds." + default_code = "throttled" def __init__(self, wait=None, detail=None, code=None): if detail is None: detail = force_text(self.default_detail) if wait is not None: wait = math.ceil(wait) - detail = ' '.join(( - detail, - force_text(ungettext(self.extra_detail_singular.format(wait=wait), - self.extra_detail_plural.format(wait=wait), - wait)))) + detail = " ".join( + ( + detail, + force_text( + ungettext( + self.extra_detail_singular.format(wait=wait), + self.extra_detail_plural.format(wait=wait), + wait, + ) + ), + ) + ) self.wait = wait super(Throttled, self).__init__(detail, code) @@ -245,9 +247,7 @@ def server_error(request, *args, **kwargs): """ Generic 500 error handler. """ - data = { - 'error': 'Server Error (500)' - } + data = {"error": "Server Error (500)"} return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR) @@ -255,7 +255,5 @@ def bad_request(request, exception, *args, **kwargs): """ Generic 400 error handler. """ - data = { - 'error': 'Bad Request (400)' - } + data = {"error": "Bad Request (400)"} return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c8f65db0e..998d4781f 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -10,16 +10,26 @@ import uuid from collections import OrderedDict from django.conf import settings -from django.core.exceptions import ObjectDoesNotExist -from django.core.exceptions import ValidationError as DjangoValidationError -from django.core.validators import ( - EmailValidator, RegexValidator, URLValidator, ip_address_validators +from django.core.exceptions import ( + ObjectDoesNotExist, + ValidationError as DjangoValidationError, +) +from django.core.validators import ( + EmailValidator, + RegexValidator, + URLValidator, + ip_address_validators, +) +from django.forms import ( + FilePathField as DjangoFilePathField, + ImageField as DjangoImageField, ) -from django.forms import FilePathField as DjangoFilePathField -from django.forms import ImageField as DjangoImageField from django.utils import six, timezone from django.utils.dateparse import ( - parse_date, parse_datetime, parse_duration, parse_time + parse_date, + parse_datetime, + parse_duration, + parse_time, ) from django.utils.duration import duration_string from django.utils.encoding import is_protected_type, smart_text @@ -32,9 +42,14 @@ from pytz.exceptions import InvalidTimeError from rest_framework import ISO_8601 from rest_framework.compat import ( - Mapping, MaxLengthValidator, MaxValueValidator, MinLengthValidator, - MinValueValidator, ProhibitNullCharactersValidator, unicode_repr, - unicode_to_repr + Mapping, + MaxLengthValidator, + MaxValueValidator, + MinLengthValidator, + MinValueValidator, + ProhibitNullCharactersValidator, + unicode_repr, + unicode_to_repr, ) from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.settings import api_settings @@ -48,27 +63,35 @@ class empty: It is required because `None` may be a valid input or output value. """ + pass if six.PY3: + def is_simple_callable(obj): """ True if the object is a callable that takes no arguments. """ - if not (inspect.isfunction(obj) or inspect.ismethod(obj) or isinstance(obj, functools.partial)): + if not ( + inspect.isfunction(obj) + or inspect.ismethod(obj) + or isinstance(obj, functools.partial) + ): return False sig = inspect.signature(obj) params = sig.parameters.values() return all( - param.kind == param.VAR_POSITIONAL or - param.kind == param.VAR_KEYWORD or - param.default != param.empty + param.kind == param.VAR_POSITIONAL + or param.kind == param.VAR_KEYWORD + or param.default != param.empty for param in params ) + else: + def is_simple_callable(obj): function = inspect.isfunction(obj) method = inspect.ismethod(obj) @@ -108,7 +131,11 @@ def get_attribute(instance, attrs): # If we raised an Attribute or KeyError here it'd get treated # as an omitted field in `Field.get_attribute()`. Instead we # raise a ValueError to ensure the exception is not masked. - raise ValueError('Exception raised in callable attribute "{0}"; original exception was: {1}'.format(attr, exc)) + raise ValueError( + 'Exception raised in callable attribute "{0}"; original exception was: {1}'.format( + attr, exc + ) + ) return instance @@ -185,6 +212,7 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None): """ Helper function for options and option groups in templates. """ + class StartOptionGroup(object): start_option_group = True end_option_group = False @@ -225,7 +253,7 @@ def iter_options(grouped_choices, cutoff=None, cutoff_text=None): if cutoff and count >= cutoff and cutoff_text: cutoff_text = cutoff_text.format(count=cutoff) - yield Option(value='n/a', display_text=cutoff_text, disabled=True) + yield Option(value="n/a", display_text=cutoff_text, disabled=True) def get_error_detail(exc_info): @@ -233,21 +261,27 @@ def get_error_detail(exc_info): Given a Django ValidationError, return a list of ErrorDetail, with the `code` populated. """ - code = getattr(exc_info, 'code', None) or 'invalid' + code = getattr(exc_info, "code", None) or "invalid" try: error_dict = exc_info.error_dict except AttributeError: return [ - ErrorDetail(error.message % (error.params or ()), - code=error.code if error.code else code) - for error in exc_info.error_list] + ErrorDetail( + error.message % (error.params or ()), + code=error.code if error.code else code, + ) + for error in exc_info.error_list + ] return { k: [ - ErrorDetail(error.message % (error.params or ()), - code=error.code if error.code else code) + ErrorDetail( + error.message % (error.params or ()), + code=error.code if error.code else code, + ) for error in errors - ] for k, errors in error_dict.items() + ] + for k, errors in error_dict.items() } @@ -257,12 +291,17 @@ class CreateOnlyDefault(object): for create operations, but that do not return any value for update operations. """ + def __init__(self, default): self.default = default def set_context(self, serializer_field): self.is_update = serializer_field.parent.instance is not None - if callable(self.default) and hasattr(self.default, 'set_context') and not self.is_update: + if ( + callable(self.default) + and hasattr(self.default, "set_context") + and not self.is_update + ): self.default.set_context(serializer_field) def __call__(self): @@ -274,34 +313,34 @@ class CreateOnlyDefault(object): def __repr__(self): return unicode_to_repr( - '%s(%s)' % (self.__class__.__name__, unicode_repr(self.default)) + "%s(%s)" % (self.__class__.__name__, unicode_repr(self.default)) ) class CurrentUserDefault(object): def set_context(self, serializer_field): - self.user = serializer_field.context['request'].user + self.user = serializer_field.context["request"].user def __call__(self): return self.user def __repr__(self): - return unicode_to_repr('%s()' % self.__class__.__name__) + return unicode_to_repr("%s()" % self.__class__.__name__) class SkipField(Exception): pass -REGEX_TYPE = type(re.compile('')) +REGEX_TYPE = type(re.compile("")) -NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`' -NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`' -NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`' -USE_READONLYFIELD = 'Field(read_only=True) should be ReadOnlyField' +NOT_READ_ONLY_WRITE_ONLY = "May not set both `read_only` and `write_only`" +NOT_READ_ONLY_REQUIRED = "May not set both `read_only` and `required`" +NOT_REQUIRED_DEFAULT = "May not set both `required` and `default`" +USE_READONLYFIELD = "Field(read_only=True) should be ReadOnlyField" MISSING_ERROR_MESSAGE = ( - 'ValidationError raised by `{class_name}`, but error key `{key}` does ' - 'not exist in the `error_messages` dictionary.' + "ValidationError raised by `{class_name}`, but error key `{key}` does " + "not exist in the `error_messages` dictionary." ) @@ -309,17 +348,28 @@ class Field(object): _creation_counter = 0 default_error_messages = { - 'required': _('This field is required.'), - 'null': _('This field may not be null.') + "required": _("This field is required."), + "null": _("This field may not be null."), } default_validators = [] default_empty_html = empty initial = None - def __init__(self, read_only=False, write_only=False, - required=None, default=empty, initial=empty, source=None, - label=None, help_text=None, style=None, - error_messages=None, validators=None, allow_null=False): + def __init__( + self, + read_only=False, + write_only=False, + required=None, + default=empty, + initial=empty, + source=None, + label=None, + help_text=None, + style=None, + error_messages=None, + validators=None, + allow_null=False, + ): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -358,7 +408,7 @@ class Field(object): # Collect default error message from self and parent classes messages = {} for cls in reversed(self.__class__.__mro__): - messages.update(getattr(cls, 'default_error_messages', {})) + messages.update(getattr(cls, "default_error_messages", {})) messages.update(error_messages or {}) self.error_messages = messages @@ -374,8 +424,8 @@ class Field(object): assert self.source != field_name, ( "It is redundant to specify `source='%s'` on field '%s' in " "serializer '%s', because it is the same as the field name. " - "Remove the `source` keyword argument." % - (field_name, self.__class__.__name__, parent.__class__.__name__) + "Remove the `source` keyword argument." + % (field_name, self.__class__.__name__, parent.__class__.__name__) ) self.field_name = field_name @@ -383,7 +433,7 @@ class Field(object): # `self.label` should default to being based on the field name. if self.label is None: - self.label = field_name.replace('_', ' ').capitalize() + self.label = field_name.replace("_", " ").capitalize() # self.source should default to being the same as the field name. if self.source is None: @@ -391,16 +441,16 @@ class Field(object): # self.source_attrs is a list of attributes that need to be looked up # when serializing the instance, or populating the validated data. - if self.source == '*': + if self.source == "*": self.source_attrs = [] else: - self.source_attrs = self.source.split('.') + self.source_attrs = self.source.split(".") # .validators is a lazily loaded property, that gets its default # value from `get_validators`. @property def validators(self): - if not hasattr(self, '_validators'): + if not hasattr(self, "_validators"): self._validators = self.get_validators() return self._validators @@ -429,18 +479,18 @@ class Field(object): # HTML forms will represent empty fields as '', and cannot # represent None or False values directly. if self.field_name not in dictionary: - if getattr(self.root, 'partial', False): + if getattr(self.root, "partial", False): return empty return self.default_empty_html ret = dictionary[self.field_name] - if ret == '' and self.allow_null: + if ret == "" and self.allow_null: # If the field is blank, and null is a valid value then # determine if we should use null instead. - return '' if getattr(self, 'allow_blank', False) else None - elif ret == '' and not self.required: + return "" if getattr(self, "allow_blank", False) else None + elif ret == "" and not self.required: # If the field is blank, and emptiness is valid then # determine if we should use emptiness instead. - return '' if getattr(self, 'allow_blank', False) else empty + return "" if getattr(self, "allow_blank", False) else empty return ret return dictionary.get(self.field_name, empty) @@ -459,16 +509,16 @@ class Field(object): if not self.required: raise SkipField() msg = ( - 'Got {exc_type} when attempting to get a value for field ' - '`{field}` on serializer `{serializer}`.\nThe serializer ' - 'field might be named incorrectly and not match ' - 'any attribute or key on the `{instance}` instance.\n' - 'Original exception text was: {exc}.'.format( + "Got {exc_type} when attempting to get a value for field " + "`{field}` on serializer `{serializer}`.\nThe serializer " + "field might be named incorrectly and not match " + "any attribute or key on the `{instance}` instance.\n" + "Original exception text was: {exc}.".format( exc_type=type(exc).__name__, field=self.field_name, serializer=self.parent.__class__.__name__, instance=instance.__class__.__name__, - exc=exc + exc=exc, ) ) raise type(exc)(msg) @@ -482,11 +532,11 @@ class Field(object): raise `SkipField`, indicating that no value should be set in the validated data for this field. """ - if self.default is empty or getattr(self.root, 'partial', False): + if self.default is empty or getattr(self.root, "partial", False): # No default, or this is a partial update. raise SkipField() if callable(self.default): - if hasattr(self.default, 'set_context'): + if hasattr(self.default, "set_context"): self.default.set_context(self) return self.default() return self.default @@ -506,15 +556,15 @@ class Field(object): return (True, self.get_default()) if data is empty: - if getattr(self.root, 'partial', False): + if getattr(self.root, "partial", False): raise SkipField() if self.required: - self.fail('required') + self.fail("required") return (True, self.get_default()) if data is None: if not self.allow_null: - self.fail('null') + self.fail("null") return (True, None) return (False, data) @@ -543,7 +593,7 @@ class Field(object): """ errors = [] for validator in self.validators: - if hasattr(validator, 'set_context'): + if hasattr(validator, "set_context"): validator.set_context(self) try: @@ -565,7 +615,7 @@ class Field(object): Transform the *incoming* primitive data into a native value. """ raise NotImplementedError( - '{cls}.to_internal_value() must be implemented.'.format( + "{cls}.to_internal_value() must be implemented.".format( cls=self.__class__.__name__ ) ) @@ -575,11 +625,10 @@ class Field(object): Transform the *outgoing* native value into primitive data. """ raise NotImplementedError( - '{cls}.to_representation() must be implemented for field ' - '{field_name}. If you do not need to support write operations ' - 'you probably want to subclass `ReadOnlyField` instead.'.format( - cls=self.__class__.__name__, - field_name=self.field_name, + "{cls}.to_representation() must be implemented for field " + "{field_name}. If you do not need to support write operations " + "you probably want to subclass `ReadOnlyField` instead.".format( + cls=self.__class__.__name__, field_name=self.field_name ) ) @@ -611,7 +660,7 @@ class Field(object): """ Returns the context as passed to the root serializer on initialization. """ - return getattr(self.root, '_context', {}) + return getattr(self.root, "_context", {}) def __new__(cls, *args, **kwargs): """ @@ -636,7 +685,9 @@ class Field(object): for item in self._args ] kwargs = { - key: (copy.deepcopy(value) if (key not in ('validators', 'regex')) else value) + key: ( + copy.deepcopy(value) if (key not in ("validators", "regex")) else value + ) for key, value in self._kwargs.items() } return self.__class__(*args, **kwargs) @@ -652,29 +703,47 @@ class Field(object): # Boolean types... + class BooleanField(Field): - default_error_messages = { - 'invalid': _('Must be a valid boolean.') - } + default_error_messages = {"invalid": _("Must be a valid boolean.")} default_empty_html = False initial = False TRUE_VALUES = { - 't', 'T', - 'y', 'Y', 'yes', 'YES', - 'true', 'True', 'TRUE', - 'on', 'On', 'ON', - '1', 1, - True + "t", + "T", + "y", + "Y", + "yes", + "YES", + "true", + "True", + "TRUE", + "on", + "On", + "ON", + "1", + 1, + True, } FALSE_VALUES = { - 'f', 'F', - 'n', 'N', 'no', 'NO', - 'false', 'False', 'FALSE', - 'off', 'Off', 'OFF', - '0', 0, 0.0, - False + "f", + "F", + "n", + "N", + "no", + "NO", + "false", + "False", + "FALSE", + "off", + "Off", + "OFF", + "0", + 0, + 0.0, + False, } - NULL_VALUES = {'null', 'Null', 'NULL', '', None} + NULL_VALUES = {"null", "Null", "NULL", "", None} def to_internal_value(self, data): try: @@ -686,7 +755,7 @@ class BooleanField(Field): return None except TypeError: # Input is an unhashable type pass - self.fail('invalid', input=data) + self.fail("invalid", input=data) def to_representation(self, value): if value in self.TRUE_VALUES: @@ -699,31 +768,48 @@ class BooleanField(Field): class NullBooleanField(Field): - default_error_messages = { - 'invalid': _('Must be a valid boolean.') - } + default_error_messages = {"invalid": _("Must be a valid boolean.")} initial = None TRUE_VALUES = { - 't', 'T', - 'y', 'Y', 'yes', 'YES', - 'true', 'True', 'TRUE', - 'on', 'On', 'ON', - '1', 1, - True + "t", + "T", + "y", + "Y", + "yes", + "YES", + "true", + "True", + "TRUE", + "on", + "On", + "ON", + "1", + 1, + True, } FALSE_VALUES = { - 'f', 'F', - 'n', 'N', 'no', 'NO', - 'false', 'False', 'FALSE', - 'off', 'Off', 'OFF', - '0', 0, 0.0, - False + "f", + "F", + "n", + "N", + "no", + "NO", + "false", + "False", + "FALSE", + "off", + "Off", + "OFF", + "0", + 0, + 0.0, + False, } - NULL_VALUES = {'null', 'Null', 'NULL', '', None} + NULL_VALUES = {"null", "Null", "NULL", "", None} def __init__(self, **kwargs): - assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.' - kwargs['allow_null'] = True + assert "allow_null" not in kwargs, "`allow_null` is not a valid option." + kwargs["allow_null"] = True super(NullBooleanField, self).__init__(**kwargs) def to_internal_value(self, data): @@ -736,7 +822,7 @@ class NullBooleanField(Field): return None except TypeError: # Input is an unhashable type pass - self.fail('invalid', input=data) + self.fail("invalid", input=data) def to_representation(self, value): if value in self.NULL_VALUES: @@ -750,33 +836,32 @@ class NullBooleanField(Field): # String types... + class CharField(Field): default_error_messages = { - 'invalid': _('Not a valid string.'), - 'blank': _('This field may not be blank.'), - 'max_length': _('Ensure this field has no more than {max_length} characters.'), - 'min_length': _('Ensure this field has at least {min_length} characters.'), + "invalid": _("Not a valid string."), + "blank": _("This field may not be blank."), + "max_length": _("Ensure this field has no more than {max_length} characters."), + "min_length": _("Ensure this field has at least {min_length} characters."), } - initial = '' + initial = "" def __init__(self, **kwargs): - self.allow_blank = kwargs.pop('allow_blank', False) - self.trim_whitespace = kwargs.pop('trim_whitespace', True) - self.max_length = kwargs.pop('max_length', None) - self.min_length = kwargs.pop('min_length', None) + self.allow_blank = kwargs.pop("allow_blank", False) + self.trim_whitespace = kwargs.pop("trim_whitespace", True) + self.max_length = kwargs.pop("max_length", None) + self.min_length = kwargs.pop("min_length", None) super(CharField, self).__init__(**kwargs) if self.max_length is not None: - message = lazy( - self.error_messages['max_length'].format, - six.text_type)(max_length=self.max_length) - self.validators.append( - MaxLengthValidator(self.max_length, message=message)) + message = lazy(self.error_messages["max_length"].format, six.text_type)( + max_length=self.max_length + ) + self.validators.append(MaxLengthValidator(self.max_length, message=message)) if self.min_length is not None: - message = lazy( - self.error_messages['min_length'].format, - six.text_type)(min_length=self.min_length) - self.validators.append( - MinLengthValidator(self.min_length, message=message)) + message = lazy(self.error_messages["min_length"].format, six.text_type)( + min_length=self.min_length + ) + self.validators.append(MinLengthValidator(self.min_length, message=message)) # ProhibitNullCharactersValidator is None on Django < 2.0 if ProhibitNullCharactersValidator is not None: @@ -786,18 +871,20 @@ class CharField(Field): # Test for the empty string here so that it does not get validated, # and so that subclasses do not need to handle it explicitly # inside the `to_internal_value()` method. - if data == '' or (self.trim_whitespace and six.text_type(data).strip() == ''): + if data == "" or (self.trim_whitespace and six.text_type(data).strip() == ""): if not self.allow_blank: - self.fail('blank') - return '' + self.fail("blank") + return "" return super(CharField, self).run_validation(data) def to_internal_value(self, data): # We're lenient with allowing basic numerics to be coerced into strings, # but other types should fail. Eg. unclear if booleans should represent as `true` or `True`, # and composites such as lists are likely user error. - if isinstance(data, bool) or not isinstance(data, six.string_types + six.integer_types + (float,)): - self.fail('invalid') + if isinstance(data, bool) or not isinstance( + data, six.string_types + six.integer_types + (float,) + ): + self.fail("invalid") value = six.text_type(data) return value.strip() if self.trim_whitespace else value @@ -806,66 +893,69 @@ class CharField(Field): class EmailField(CharField): - default_error_messages = { - 'invalid': _('Enter a valid email address.') - } + default_error_messages = {"invalid": _("Enter a valid email address.")} def __init__(self, **kwargs): super(EmailField, self).__init__(**kwargs) - validator = EmailValidator(message=self.error_messages['invalid']) + validator = EmailValidator(message=self.error_messages["invalid"]) self.validators.append(validator) class RegexField(CharField): default_error_messages = { - 'invalid': _('This value does not match the required pattern.') + "invalid": _("This value does not match the required pattern.") } def __init__(self, regex, **kwargs): super(RegexField, self).__init__(**kwargs) - validator = RegexValidator(regex, message=self.error_messages['invalid']) + validator = RegexValidator(regex, message=self.error_messages["invalid"]) self.validators.append(validator) class SlugField(CharField): default_error_messages = { - 'invalid': _('Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.'), - 'invalid_unicode': _('Enter a valid "slug" consisting of Unicode letters, numbers, underscores, or hyphens.') + "invalid": _( + 'Enter a valid "slug" consisting of letters, numbers, underscores or hyphens.' + ), + "invalid_unicode": _( + 'Enter a valid "slug" consisting of Unicode letters, numbers, underscores, or hyphens.' + ), } def __init__(self, allow_unicode=False, **kwargs): super(SlugField, self).__init__(**kwargs) self.allow_unicode = allow_unicode if self.allow_unicode: - validator = RegexValidator(re.compile(r'^[-\w]+\Z', re.UNICODE), message=self.error_messages['invalid_unicode']) + validator = RegexValidator( + re.compile(r"^[-\w]+\Z", re.UNICODE), + message=self.error_messages["invalid_unicode"], + ) else: - validator = RegexValidator(re.compile(r'^[-a-zA-Z0-9_]+$'), message=self.error_messages['invalid']) + validator = RegexValidator( + re.compile(r"^[-a-zA-Z0-9_]+$"), message=self.error_messages["invalid"] + ) self.validators.append(validator) class URLField(CharField): - default_error_messages = { - 'invalid': _('Enter a valid URL.') - } + default_error_messages = {"invalid": _("Enter a valid URL.")} def __init__(self, **kwargs): super(URLField, self).__init__(**kwargs) - validator = URLValidator(message=self.error_messages['invalid']) + validator = URLValidator(message=self.error_messages["invalid"]) self.validators.append(validator) class UUIDField(Field): - valid_formats = ('hex_verbose', 'hex', 'int', 'urn') + valid_formats = ("hex_verbose", "hex", "int", "urn") - default_error_messages = { - 'invalid': _('Must be a valid UUID.'), - } + default_error_messages = {"invalid": _("Must be a valid UUID.")} def __init__(self, **kwargs): - self.uuid_format = kwargs.pop('format', 'hex_verbose') + self.uuid_format = kwargs.pop("format", "hex_verbose") if self.uuid_format not in self.valid_formats: raise ValueError( - 'Invalid format for uuid representation. ' + "Invalid format for uuid representation. " 'Must be one of "{0}"'.format('", "'.join(self.valid_formats)) ) super(UUIDField, self).__init__(**kwargs) @@ -878,13 +968,13 @@ class UUIDField(Field): elif isinstance(data, six.string_types): return uuid.UUID(hex=data) else: - self.fail('invalid', value=data) + self.fail("invalid", value=data) except (ValueError): - self.fail('invalid', value=data) + self.fail("invalid", value=data) return data def to_representation(self, value): - if self.uuid_format == 'hex_verbose': + if self.uuid_format == "hex_verbose": return str(value) else: return getattr(value, self.uuid_format) @@ -893,68 +983,65 @@ class UUIDField(Field): class IPAddressField(CharField): """Support both IPAddressField and GenericIPAddressField""" - default_error_messages = { - 'invalid': _('Enter a valid IPv4 or IPv6 address.'), - } + default_error_messages = {"invalid": _("Enter a valid IPv4 or IPv6 address.")} - def __init__(self, protocol='both', **kwargs): + def __init__(self, protocol="both", **kwargs): self.protocol = protocol.lower() - self.unpack_ipv4 = (self.protocol == 'both') + self.unpack_ipv4 = self.protocol == "both" super(IPAddressField, self).__init__(**kwargs) validators, error_message = ip_address_validators(protocol, self.unpack_ipv4) self.validators.extend(validators) def to_internal_value(self, data): if not isinstance(data, six.string_types): - self.fail('invalid', value=data) + self.fail("invalid", value=data) - if ':' in data: + if ":" in data: try: - if self.protocol in ('both', 'ipv6'): + if self.protocol in ("both", "ipv6"): return clean_ipv6_address(data, self.unpack_ipv4) except DjangoValidationError: - self.fail('invalid', value=data) + self.fail("invalid", value=data) return super(IPAddressField, self).to_internal_value(data) # Number types... + class IntegerField(Field): default_error_messages = { - 'invalid': _('A valid integer is required.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), - 'max_string_length': _('String value too large.') + "invalid": _("A valid integer is required."), + "max_value": _("Ensure this value is less than or equal to {max_value}."), + "min_value": _("Ensure this value is greater than or equal to {min_value}."), + "max_string_length": _("String value too large."), } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - re_decimal = re.compile(r'\.0*\s*$') # allow e.g. '1.0' as an int, but not '1.2' + re_decimal = re.compile(r"\.0*\s*$") # allow e.g. '1.0' as an int, but not '1.2' def __init__(self, **kwargs): - self.max_value = kwargs.pop('max_value', None) - self.min_value = kwargs.pop('min_value', None) + self.max_value = kwargs.pop("max_value", None) + self.min_value = kwargs.pop("min_value", None) super(IntegerField, self).__init__(**kwargs) if self.max_value is not None: - message = lazy( - self.error_messages['max_value'].format, - six.text_type)(max_value=self.max_value) - self.validators.append( - MaxValueValidator(self.max_value, message=message)) + message = lazy(self.error_messages["max_value"].format, six.text_type)( + max_value=self.max_value + ) + self.validators.append(MaxValueValidator(self.max_value, message=message)) if self.min_value is not None: - message = lazy( - self.error_messages['min_value'].format, - six.text_type)(min_value=self.min_value) - self.validators.append( - MinValueValidator(self.min_value, message=message)) + message = lazy(self.error_messages["min_value"].format, six.text_type)( + min_value=self.min_value + ) + self.validators.append(MinValueValidator(self.min_value, message=message)) def to_internal_value(self, data): if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: - self.fail('max_string_length') + self.fail("max_string_length") try: - data = int(self.re_decimal.sub('', str(data))) + data = int(self.re_decimal.sub("", str(data))) except (ValueError, TypeError): - self.fail('invalid') + self.fail("invalid") return data def to_representation(self, value): @@ -963,39 +1050,37 @@ class IntegerField(Field): class FloatField(Field): default_error_messages = { - 'invalid': _('A valid number is required.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), - 'max_string_length': _('String value too large.') + "invalid": _("A valid number is required."), + "max_value": _("Ensure this value is less than or equal to {max_value}."), + "min_value": _("Ensure this value is greater than or equal to {min_value}."), + "max_string_length": _("String value too large."), } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. def __init__(self, **kwargs): - self.max_value = kwargs.pop('max_value', None) - self.min_value = kwargs.pop('min_value', None) + self.max_value = kwargs.pop("max_value", None) + self.min_value = kwargs.pop("min_value", None) super(FloatField, self).__init__(**kwargs) if self.max_value is not None: - message = lazy( - self.error_messages['max_value'].format, - six.text_type)(max_value=self.max_value) - self.validators.append( - MaxValueValidator(self.max_value, message=message)) + message = lazy(self.error_messages["max_value"].format, six.text_type)( + max_value=self.max_value + ) + self.validators.append(MaxValueValidator(self.max_value, message=message)) if self.min_value is not None: - message = lazy( - self.error_messages['min_value'].format, - six.text_type)(min_value=self.min_value) - self.validators.append( - MinValueValidator(self.min_value, message=message)) + message = lazy(self.error_messages["min_value"].format, six.text_type)( + min_value=self.min_value + ) + self.validators.append(MinValueValidator(self.min_value, message=message)) def to_internal_value(self, data): if isinstance(data, six.text_type) and len(data) > self.MAX_STRING_LENGTH: - self.fail('max_string_length') + self.fail("max_string_length") try: return float(data) except (TypeError, ValueError): - self.fail('invalid') + self.fail("invalid") def to_representation(self, value): return float(value) @@ -1003,18 +1088,33 @@ class FloatField(Field): class DecimalField(Field): default_error_messages = { - 'invalid': _('A valid number is required.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), - 'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'), - 'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'), - 'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.'), - 'max_string_length': _('String value too large.') + "invalid": _("A valid number is required."), + "max_value": _("Ensure this value is less than or equal to {max_value}."), + "min_value": _("Ensure this value is greater than or equal to {min_value}."), + "max_digits": _( + "Ensure that there are no more than {max_digits} digits in total." + ), + "max_decimal_places": _( + "Ensure that there are no more than {max_decimal_places} decimal places." + ), + "max_whole_digits": _( + "Ensure that there are no more than {max_whole_digits} digits before the decimal point." + ), + "max_string_length": _("String value too large."), } MAX_STRING_LENGTH = 1000 # Guard against malicious string inputs. - def __init__(self, max_digits, decimal_places, coerce_to_string=None, max_value=None, min_value=None, - localize=False, rounding=None, **kwargs): + def __init__( + self, + max_digits, + decimal_places, + coerce_to_string=None, + max_value=None, + min_value=None, + localize=False, + rounding=None, + **kwargs + ): self.max_digits = max_digits self.decimal_places = decimal_places self.localize = localize @@ -1034,22 +1134,24 @@ class DecimalField(Field): super(DecimalField, self).__init__(**kwargs) if self.max_value is not None: - message = lazy( - self.error_messages['max_value'].format, - six.text_type)(max_value=self.max_value) - self.validators.append( - MaxValueValidator(self.max_value, message=message)) + message = lazy(self.error_messages["max_value"].format, six.text_type)( + max_value=self.max_value + ) + self.validators.append(MaxValueValidator(self.max_value, message=message)) if self.min_value is not None: - message = lazy( - self.error_messages['min_value'].format, - six.text_type)(min_value=self.min_value) - self.validators.append( - MinValueValidator(self.min_value, message=message)) + message = lazy(self.error_messages["min_value"].format, six.text_type)( + min_value=self.min_value + ) + self.validators.append(MinValueValidator(self.min_value, message=message)) if rounding is not None: - valid_roundings = [v for k, v in vars(decimal).items() if k.startswith('ROUND_')] + valid_roundings = [ + v for k, v in vars(decimal).items() if k.startswith("ROUND_") + ] assert rounding in valid_roundings, ( - 'Invalid rounding option %s. Valid values for rounding are: %s' % (rounding, valid_roundings)) + "Invalid rounding option %s. Valid values for rounding are: %s" + % (rounding, valid_roundings) + ) self.rounding = rounding def to_internal_value(self, data): @@ -1064,21 +1166,21 @@ class DecimalField(Field): data = sanitize_separators(data) if len(data) > self.MAX_STRING_LENGTH: - self.fail('max_string_length') + self.fail("max_string_length") try: value = decimal.Decimal(data) except decimal.DecimalException: - self.fail('invalid') + self.fail("invalid") # Check for NaN. It is the only value that isn't equal to itself, # so we can use this to identify NaN values. if value != value: - self.fail('invalid') + self.fail("invalid") # Check for infinity and negative infinity. - if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')): - self.fail('invalid') + if value in (decimal.Decimal("Inf"), decimal.Decimal("-Inf")): + self.fail("invalid") return self.quantize(self.validate_precision(value)) @@ -1109,16 +1211,18 @@ class DecimalField(Field): decimal_places = total_digits if self.max_digits is not None and total_digits > self.max_digits: - self.fail('max_digits', max_digits=self.max_digits) + self.fail("max_digits", max_digits=self.max_digits) if self.decimal_places is not None and decimal_places > self.decimal_places: - self.fail('max_decimal_places', max_decimal_places=self.decimal_places) + self.fail("max_decimal_places", max_decimal_places=self.decimal_places) if self.max_whole_digits is not None and whole_digits > self.max_whole_digits: - self.fail('max_whole_digits', max_whole_digits=self.max_whole_digits) + self.fail("max_whole_digits", max_whole_digits=self.max_whole_digits) return value def to_representation(self, value): - coerce_to_string = getattr(self, 'coerce_to_string', api_settings.COERCE_DECIMAL_TO_STRING) + coerce_to_string = getattr( + self, "coerce_to_string", api_settings.COERCE_DECIMAL_TO_STRING + ) if not isinstance(value, decimal.Decimal): value = decimal.Decimal(six.text_type(value).strip()) @@ -1130,7 +1234,7 @@ class DecimalField(Field): if self.localize: return localize_input(quantized) - return '{0:f}'.format(quantized) + return "{0:f}".format(quantized) def quantize(self, value): """ @@ -1143,24 +1247,29 @@ class DecimalField(Field): if self.max_digits is not None: context.prec = self.max_digits return value.quantize( - decimal.Decimal('.1') ** self.decimal_places, + decimal.Decimal(".1") ** self.decimal_places, rounding=self.rounding, - context=context + context=context, ) # Date & time fields... + class DateTimeField(Field): default_error_messages = { - 'invalid': _('Datetime has wrong format. Use one of these formats instead: {format}.'), - 'date': _('Expected a datetime but got a date.'), - 'make_aware': _('Invalid datetime for the timezone "{timezone}".'), - 'overflow': _('Datetime value out of range.') + "invalid": _( + "Datetime has wrong format. Use one of these formats instead: {format}." + ), + "date": _("Expected a datetime but got a date."), + "make_aware": _('Invalid datetime for the timezone "{timezone}".'), + "overflow": _("Datetime value out of range."), } datetime_parser = datetime.datetime.strptime - def __init__(self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs): + def __init__( + self, format=empty, input_formats=None, default_timezone=None, *args, **kwargs + ): if format is not empty: self.format = format if input_formats is not None: @@ -1174,18 +1283,18 @@ class DateTimeField(Field): When `self.default_timezone` is `None`, always return naive datetimes. When `self.default_timezone` is not `None`, always return aware datetimes. """ - field_timezone = getattr(self, 'timezone', self.default_timezone()) + field_timezone = getattr(self, "timezone", self.default_timezone()) if field_timezone is not None: if timezone.is_aware(value): try: return value.astimezone(field_timezone) except OverflowError: - self.fail('overflow') + self.fail("overflow") try: return timezone.make_aware(value, field_timezone) except InvalidTimeError: - self.fail('make_aware', timezone=field_timezone) + self.fail("make_aware", timezone=field_timezone) elif (field_timezone is None) and timezone.is_aware(value): return timezone.make_naive(value, utc) return value @@ -1194,10 +1303,14 @@ class DateTimeField(Field): return timezone.get_current_timezone() if settings.USE_TZ else None def to_internal_value(self, value): - input_formats = getattr(self, 'input_formats', api_settings.DATETIME_INPUT_FORMATS) + input_formats = getattr( + self, "input_formats", api_settings.DATETIME_INPUT_FORMATS + ) - if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime): - self.fail('date') + if isinstance(value, datetime.date) and not isinstance( + value, datetime.datetime + ): + self.fail("date") if isinstance(value, datetime.datetime): return self.enforce_timezone(value) @@ -1218,13 +1331,13 @@ class DateTimeField(Field): pass humanized_format = humanize_datetime.datetime_formats(input_formats) - self.fail('invalid', format=humanized_format) + self.fail("invalid", format=humanized_format) def to_representation(self, value): if not value: return None - output_format = getattr(self, 'format', api_settings.DATETIME_FORMAT) + output_format = getattr(self, "format", api_settings.DATETIME_FORMAT) if output_format is None or isinstance(value, six.string_types): return value @@ -1233,16 +1346,18 @@ class DateTimeField(Field): if output_format.lower() == ISO_8601: value = value.isoformat() - if value.endswith('+00:00'): - value = value[:-6] + 'Z' + if value.endswith("+00:00"): + value = value[:-6] + "Z" return value return value.strftime(output_format) class DateField(Field): default_error_messages = { - 'invalid': _('Date has wrong format. Use one of these formats instead: {format}.'), - 'datetime': _('Expected a date but got a datetime.'), + "invalid": _( + "Date has wrong format. Use one of these formats instead: {format}." + ), + "datetime": _("Expected a date but got a datetime."), } datetime_parser = datetime.datetime.strptime @@ -1254,10 +1369,10 @@ class DateField(Field): super(DateField, self).__init__(*args, **kwargs) def to_internal_value(self, value): - input_formats = getattr(self, 'input_formats', api_settings.DATE_INPUT_FORMATS) + input_formats = getattr(self, "input_formats", api_settings.DATE_INPUT_FORMATS) if isinstance(value, datetime.datetime): - self.fail('datetime') + self.fail("datetime") if isinstance(value, datetime.date): return value @@ -1280,13 +1395,13 @@ class DateField(Field): return parsed.date() humanized_format = humanize_datetime.date_formats(input_formats) - self.fail('invalid', format=humanized_format) + self.fail("invalid", format=humanized_format) def to_representation(self, value): if not value: return None - output_format = getattr(self, 'format', api_settings.DATE_FORMAT) + output_format = getattr(self, "format", api_settings.DATE_FORMAT) if output_format is None or isinstance(value, six.string_types): return value @@ -1295,9 +1410,9 @@ class DateField(Field): # not a sensible thing to do, as it means naively dropping # any explicit or implicit timezone info. assert not isinstance(value, datetime.datetime), ( - 'Expected a `date`, but got a `datetime`. Refusing to coerce, ' - 'as this may mean losing timezone information. Use a custom ' - 'read-only field and deal with timezone issues explicitly.' + "Expected a `date`, but got a `datetime`. Refusing to coerce, " + "as this may mean losing timezone information. Use a custom " + "read-only field and deal with timezone issues explicitly." ) if output_format.lower() == ISO_8601: @@ -1308,7 +1423,9 @@ class DateField(Field): class TimeField(Field): default_error_messages = { - 'invalid': _('Time has wrong format. Use one of these formats instead: {format}.'), + "invalid": _( + "Time has wrong format. Use one of these formats instead: {format}." + ) } datetime_parser = datetime.datetime.strptime @@ -1320,7 +1437,7 @@ class TimeField(Field): super(TimeField, self).__init__(*args, **kwargs) def to_internal_value(self, value): - input_formats = getattr(self, 'input_formats', api_settings.TIME_INPUT_FORMATS) + input_formats = getattr(self, "input_formats", api_settings.TIME_INPUT_FORMATS) if isinstance(value, datetime.time): return value @@ -1343,13 +1460,13 @@ class TimeField(Field): return parsed.time() humanized_format = humanize_datetime.time_formats(input_formats) - self.fail('invalid', format=humanized_format) + self.fail("invalid", format=humanized_format) def to_representation(self, value): - if value in (None, ''): + if value in (None, ""): return None - output_format = getattr(self, 'format', api_settings.TIME_FORMAT) + output_format = getattr(self, "format", api_settings.TIME_FORMAT) if output_format is None or isinstance(value, six.string_types): return value @@ -1358,9 +1475,9 @@ class TimeField(Field): # not a sensible thing to do, as it means naively dropping # any explicit or implicit timezone info. assert not isinstance(value, datetime.datetime), ( - 'Expected a `time`, but got a `datetime`. Refusing to coerce, ' - 'as this may mean losing timezone information. Use a custom ' - 'read-only field and deal with timezone issues explicitly.' + "Expected a `time`, but got a `datetime`. Refusing to coerce, " + "as this may mean losing timezone information. Use a custom " + "read-only field and deal with timezone issues explicitly." ) if output_format.lower() == ISO_8601: @@ -1370,27 +1487,27 @@ class TimeField(Field): class DurationField(Field): default_error_messages = { - 'invalid': _('Duration has wrong format. Use one of these formats instead: {format}.'), - 'max_value': _('Ensure this value is less than or equal to {max_value}.'), - 'min_value': _('Ensure this value is greater than or equal to {min_value}.'), + "invalid": _( + "Duration has wrong format. Use one of these formats instead: {format}." + ), + "max_value": _("Ensure this value is less than or equal to {max_value}."), + "min_value": _("Ensure this value is greater than or equal to {min_value}."), } def __init__(self, **kwargs): - self.max_value = kwargs.pop('max_value', None) - self.min_value = kwargs.pop('min_value', None) + self.max_value = kwargs.pop("max_value", None) + self.min_value = kwargs.pop("min_value", None) super(DurationField, self).__init__(**kwargs) if self.max_value is not None: - message = lazy( - self.error_messages['max_value'].format, - six.text_type)(max_value=self.max_value) - self.validators.append( - MaxValueValidator(self.max_value, message=message)) + message = lazy(self.error_messages["max_value"].format, six.text_type)( + max_value=self.max_value + ) + self.validators.append(MaxValueValidator(self.max_value, message=message)) if self.min_value is not None: - message = lazy( - self.error_messages['min_value'].format, - six.text_type)(min_value=self.min_value) - self.validators.append( - MinValueValidator(self.min_value, message=message)) + message = lazy(self.error_messages["min_value"].format, six.text_type)( + min_value=self.min_value + ) + self.validators.append(MinValueValidator(self.min_value, message=message)) def to_internal_value(self, value): if isinstance(value, datetime.timedelta): @@ -1398,7 +1515,7 @@ class DurationField(Field): parsed = parse_duration(six.text_type(value)) if parsed is not None: return parsed - self.fail('invalid', format='[DD] [HH:[MM:]]ss[.uuuuuu]') + self.fail("invalid", format="[DD] [HH:[MM:]]ss[.uuuuuu]") def to_representation(self, value): return duration_string(value) @@ -1406,33 +1523,32 @@ class DurationField(Field): # Choice types... + class ChoiceField(Field): - default_error_messages = { - 'invalid_choice': _('"{input}" is not a valid choice.') - } + default_error_messages = {"invalid_choice": _('"{input}" is not a valid choice.')} html_cutoff = None - html_cutoff_text = _('More than {count} items...') + html_cutoff_text = _("More than {count} items...") def __init__(self, choices, **kwargs): self.choices = choices - self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) - self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) + self.html_cutoff = kwargs.pop("html_cutoff", self.html_cutoff) + self.html_cutoff_text = kwargs.pop("html_cutoff_text", self.html_cutoff_text) - self.allow_blank = kwargs.pop('allow_blank', False) + self.allow_blank = kwargs.pop("allow_blank", False) super(ChoiceField, self).__init__(**kwargs) def to_internal_value(self, data): - if data == '' and self.allow_blank: - return '' + if data == "" and self.allow_blank: + return "" try: return self.choice_strings_to_values[six.text_type(data)] except KeyError: - self.fail('invalid_choice', input=data) + self.fail("invalid_choice", input=data) def to_representation(self, value): - if value in ('', None): + if value in ("", None): return value return self.choice_strings_to_values.get(six.text_type(value), value) @@ -1443,7 +1559,7 @@ class ChoiceField(Field): return iter_options( self.grouped_choices, cutoff=self.html_cutoff, - cutoff_text=self.html_cutoff_text + cutoff_text=self.html_cutoff_text, ) def _get_choices(self): @@ -1465,19 +1581,19 @@ class ChoiceField(Field): class MultipleChoiceField(ChoiceField): default_error_messages = { - 'invalid_choice': _('"{input}" is not a valid choice.'), - 'not_a_list': _('Expected a list of items but got type "{input_type}".'), - 'empty': _('This selection may not be empty.') + "invalid_choice": _('"{input}" is not a valid choice.'), + "not_a_list": _('Expected a list of items but got type "{input_type}".'), + "empty": _("This selection may not be empty."), } default_empty_html = [] def __init__(self, *args, **kwargs): - self.allow_empty = kwargs.pop('allow_empty', True) + self.allow_empty = kwargs.pop("allow_empty", True) super(MultipleChoiceField, self).__init__(*args, **kwargs) def get_value(self, dictionary): if self.field_name not in dictionary: - if getattr(self.root, 'partial', False): + if getattr(self.root, "partial", False): return empty # We override the default field access in order to support # lists in HTML forms. @@ -1486,55 +1602,72 @@ class MultipleChoiceField(ChoiceField): return dictionary.get(self.field_name, empty) def to_internal_value(self, data): - if isinstance(data, six.text_type) or not hasattr(data, '__iter__'): - self.fail('not_a_list', input_type=type(data).__name__) + if isinstance(data, six.text_type) or not hasattr(data, "__iter__"): + self.fail("not_a_list", input_type=type(data).__name__) if not self.allow_empty and len(data) == 0: - self.fail('empty') + self.fail("empty") return { - super(MultipleChoiceField, self).to_internal_value(item) - for item in data + super(MultipleChoiceField, self).to_internal_value(item) for item in data } def to_representation(self, value): return { - self.choice_strings_to_values.get(six.text_type(item), item) for item in value + self.choice_strings_to_values.get(six.text_type(item), item) + for item in value } class FilePathField(ChoiceField): default_error_messages = { - 'invalid_choice': _('"{input}" is not a valid path choice.') + "invalid_choice": _('"{input}" is not a valid path choice.') } - def __init__(self, path, match=None, recursive=False, allow_files=True, - allow_folders=False, required=None, **kwargs): + def __init__( + self, + path, + match=None, + recursive=False, + allow_files=True, + allow_folders=False, + required=None, + **kwargs + ): # Defer to Django's FilePathField implementation to get the # valid set of choices. field = DjangoFilePathField( - path, match=match, recursive=recursive, allow_files=allow_files, - allow_folders=allow_folders, required=required + path, + match=match, + recursive=recursive, + allow_files=allow_files, + allow_folders=allow_folders, + required=required, ) - kwargs['choices'] = field.choices + kwargs["choices"] = field.choices super(FilePathField, self).__init__(**kwargs) # File types... + class FileField(Field): default_error_messages = { - 'required': _('No file was submitted.'), - 'invalid': _('The submitted data was not a file. Check the encoding type on the form.'), - 'no_name': _('No filename could be determined.'), - 'empty': _('The submitted file is empty.'), - 'max_length': _('Ensure this filename has at most {max_length} characters (it has {length}).'), + "required": _("No file was submitted."), + "invalid": _( + "The submitted data was not a file. Check the encoding type on the form." + ), + "no_name": _("No filename could be determined."), + "empty": _("The submitted file is empty."), + "max_length": _( + "Ensure this filename has at most {max_length} characters (it has {length})." + ), } def __init__(self, *args, **kwargs): - self.max_length = kwargs.pop('max_length', None) - self.allow_empty_file = kwargs.pop('allow_empty_file', False) - if 'use_url' in kwargs: - self.use_url = kwargs.pop('use_url') + self.max_length = kwargs.pop("max_length", None) + self.allow_empty_file = kwargs.pop("allow_empty_file", False) + if "use_url" in kwargs: + self.use_url = kwargs.pop("use_url") super(FileField, self).__init__(*args, **kwargs) def to_internal_value(self, data): @@ -1543,14 +1676,14 @@ class FileField(Field): file_name = data.name file_size = data.size except AttributeError: - self.fail('invalid') + self.fail("invalid") if not file_name: - self.fail('no_name') + self.fail("no_name") if not self.allow_empty_file and not file_size: - self.fail('empty') + self.fail("empty") if self.max_length and len(file_name) > self.max_length: - self.fail('max_length', max_length=self.max_length, length=len(file_name)) + self.fail("max_length", max_length=self.max_length, length=len(file_name)) return data @@ -1558,14 +1691,14 @@ class FileField(Field): if not value: return None - use_url = getattr(self, 'use_url', api_settings.UPLOADED_FILES_USE_URL) + use_url = getattr(self, "use_url", api_settings.UPLOADED_FILES_USE_URL) if use_url: - if not getattr(value, 'url', None): + if not getattr(value, "url", None): # If the file has not been saved it may not have a URL. return None url = value.url - request = self.context.get('request', None) + request = self.context.get("request", None) if request is not None: return request.build_absolute_uri(url) return url @@ -1574,13 +1707,13 @@ class FileField(Field): class ImageField(FileField): default_error_messages = { - 'invalid_image': _( - 'Upload a valid image. The file you uploaded was either not an image or a corrupted image.' - ), + "invalid_image": _( + "Upload a valid image. The file you uploaded was either not an image or a corrupted image." + ) } def __init__(self, *args, **kwargs): - self._DjangoImageField = kwargs.pop('_DjangoImageField', DjangoImageField) + self._DjangoImageField = kwargs.pop("_DjangoImageField", DjangoImageField) super(ImageField, self).__init__(*args, **kwargs) def to_internal_value(self, data): @@ -1595,6 +1728,7 @@ class ImageField(FileField): # Composite field types... + class _UnvalidatedField(Field): def __init__(self, *args, **kwargs): super(_UnvalidatedField, self).__init__(*args, **kwargs) @@ -1612,36 +1746,40 @@ class ListField(Field): child = _UnvalidatedField() initial = [] default_error_messages = { - 'not_a_list': _('Expected a list of items but got type "{input_type}".'), - 'empty': _('This list may not be empty.'), - 'min_length': _('Ensure this field has at least {min_length} elements.'), - 'max_length': _('Ensure this field has no more than {max_length} elements.') + "not_a_list": _('Expected a list of items but got type "{input_type}".'), + "empty": _("This list may not be empty."), + "min_length": _("Ensure this field has at least {min_length} elements."), + "max_length": _("Ensure this field has no more than {max_length} elements."), } def __init__(self, *args, **kwargs): - self.child = kwargs.pop('child', copy.deepcopy(self.child)) - self.allow_empty = kwargs.pop('allow_empty', True) - self.max_length = kwargs.pop('max_length', None) - self.min_length = kwargs.pop('min_length', None) + self.child = kwargs.pop("child", copy.deepcopy(self.child)) + self.allow_empty = kwargs.pop("allow_empty", True) + self.max_length = kwargs.pop("max_length", None) + self.min_length = kwargs.pop("min_length", None) - assert not inspect.isclass(self.child), '`child` has not been instantiated.' + assert not inspect.isclass(self.child), "`child` has not been instantiated." assert self.child.source is None, ( "The `source` argument is not meaningful when applied to a `child=` field. " "Remove `source=` from the field declaration." ) super(ListField, self).__init__(*args, **kwargs) - self.child.bind(field_name='', parent=self) + self.child.bind(field_name="", parent=self) if self.max_length is not None: - message = self.error_messages['max_length'].format(max_length=self.max_length) + message = self.error_messages["max_length"].format( + max_length=self.max_length + ) self.validators.append(MaxLengthValidator(self.max_length, message=message)) if self.min_length is not None: - message = self.error_messages['min_length'].format(min_length=self.min_length) + message = self.error_messages["min_length"].format( + min_length=self.min_length + ) self.validators.append(MinLengthValidator(self.min_length, message=message)) def get_value(self, dictionary): if self.field_name not in dictionary: - if getattr(self.root, 'partial', False): + if getattr(self.root, "partial", False): return empty # We override the default field access in order to support # lists in HTML forms. @@ -1650,7 +1788,9 @@ class ListField(Field): if len(val) > 0: # Support QueryDict lists in HTML input. return val - return html.parse_html_list(dictionary, prefix=self.field_name, default=empty) + return html.parse_html_list( + dictionary, prefix=self.field_name, default=empty + ) return dictionary.get(self.field_name, empty) @@ -1660,17 +1800,20 @@ class ListField(Field): """ if html.is_html_input(data): data = html.parse_html_list(data, default=[]) - if isinstance(data, (six.text_type, Mapping)) or not hasattr(data, '__iter__'): - self.fail('not_a_list', input_type=type(data).__name__) + if isinstance(data, (six.text_type, Mapping)) or not hasattr(data, "__iter__"): + self.fail("not_a_list", input_type=type(data).__name__) if not self.allow_empty and len(data) == 0: - self.fail('empty') + self.fail("empty") return self.run_child_validation(data) def to_representation(self, data): """ List of object instances -> List of dicts of primitive datatypes. """ - return [self.child.to_representation(item) if item is not None else None for item in data] + return [ + self.child.to_representation(item) if item is not None else None + for item in data + ] def run_child_validation(self, data): result = [] @@ -1691,20 +1834,20 @@ class DictField(Field): child = _UnvalidatedField() initial = {} default_error_messages = { - 'not_a_dict': _('Expected a dictionary of items but got type "{input_type}".') + "not_a_dict": _('Expected a dictionary of items but got type "{input_type}".') } def __init__(self, *args, **kwargs): - self.child = kwargs.pop('child', copy.deepcopy(self.child)) + self.child = kwargs.pop("child", copy.deepcopy(self.child)) - assert not inspect.isclass(self.child), '`child` has not been instantiated.' + assert not inspect.isclass(self.child), "`child` has not been instantiated." assert self.child.source is None, ( "The `source` argument is not meaningful when applied to a `child=` field. " "Remove `source=` from the field declaration." ) super(DictField, self).__init__(*args, **kwargs) - self.child.bind(field_name='', parent=self) + self.child.bind(field_name="", parent=self) def get_value(self, dictionary): # We override the default field access in order to support @@ -1720,12 +1863,14 @@ class DictField(Field): if html.is_html_input(data): data = html.parse_html_dict(data) if not isinstance(data, dict): - self.fail('not_a_dict', input_type=type(data).__name__) + self.fail("not_a_dict", input_type=type(data).__name__) return self.run_child_validation(data) def to_representation(self, value): return { - six.text_type(key): self.child.to_representation(val) if val is not None else None + six.text_type(key): self.child.to_representation(val) + if val is not None + else None for key, val in value.items() } @@ -1758,12 +1903,10 @@ class HStoreField(DictField): class JSONField(Field): - default_error_messages = { - 'invalid': _('Value must be valid JSON.') - } + default_error_messages = {"invalid": _("Value must be valid JSON.")} def __init__(self, *args, **kwargs): - self.binary = kwargs.pop('binary', False) + self.binary = kwargs.pop("binary", False) super(JSONField, self).__init__(*args, **kwargs) def get_value(self, dictionary): @@ -1775,19 +1918,20 @@ class JSONField(Field): ret = six.text_type.__new__(self, value) ret.is_json_string = True return ret + return JSONString(dictionary[self.field_name]) return dictionary.get(self.field_name, empty) def to_internal_value(self, data): try: - if self.binary or getattr(data, 'is_json_string', False): + if self.binary or getattr(data, "is_json_string", False): if isinstance(data, bytes): - data = data.decode('utf-8') + data = data.decode("utf-8") return json.loads(data) else: json.dumps(data) except (TypeError, ValueError): - self.fail('invalid') + self.fail("invalid") return data def to_representation(self, value): @@ -1796,12 +1940,13 @@ class JSONField(Field): # On python 2.x the return type for json.dumps() is underspecified. # On python 3.x json.dumps() returns unicode strings. if isinstance(value, six.text_type): - value = bytes(value.encode('utf-8')) + value = bytes(value.encode("utf-8")) return value # Miscellaneous field types... + class ReadOnlyField(Field): """ A read-only field that simply returns the field value. @@ -1816,7 +1961,7 @@ class ReadOnlyField(Field): """ def __init__(self, **kwargs): - kwargs['read_only'] = True + kwargs["read_only"] = True super(ReadOnlyField, self).__init__(**kwargs) def to_representation(self, value): @@ -1831,9 +1976,10 @@ class HiddenField(Field): constraint on a pair of fields, as we need some way to include the date in the validated data. """ + def __init__(self, **kwargs): - assert 'default' in kwargs, 'default is a required argument.' - kwargs['write_only'] = True + assert "default" in kwargs, "default is a required argument." + kwargs["write_only"] = True super(HiddenField, self).__init__(**kwargs) def get_value(self, dictionary): @@ -1860,22 +2006,23 @@ class SerializerMethodField(Field): def get_extra_info(self, obj): return ... # Calculate some data to return. """ + def __init__(self, method_name=None, **kwargs): self.method_name = method_name - kwargs['source'] = '*' - kwargs['read_only'] = True + kwargs["source"] = "*" + kwargs["read_only"] = True super(SerializerMethodField, self).__init__(**kwargs) def bind(self, field_name, parent): # In order to enforce a consistent style, we error if a redundant # 'method_name' argument has been used. For example: # my_field = serializer.SerializerMethodField(method_name='get_my_field') - default_method_name = 'get_{field_name}'.format(field_name=field_name) + default_method_name = "get_{field_name}".format(field_name=field_name) assert self.method_name != default_method_name, ( "It is redundant to specify `%s` on SerializerMethodField '%s' in " "serializer '%s', because it is the same as the default method name. " - "Remove the `method_name` argument." % - (self.method_name, field_name, parent.__class__.__name__) + "Remove the `method_name` argument." + % (self.method_name, field_name, parent.__class__.__name__) ) # The method name should default to `get_{field_name}`. @@ -1896,22 +2043,22 @@ class ModelField(Field): This is used by `ModelSerializer` when dealing with custom model fields, that do not have a serializer field to be mapped to. """ + default_error_messages = { - 'max_length': _('Ensure this field has no more than {max_length} characters.'), + "max_length": _("Ensure this field has no more than {max_length} characters.") } def __init__(self, model_field, **kwargs): self.model_field = model_field # The `max_length` option is supported by Django's base `Field` class, # so we'd better support it here. - max_length = kwargs.pop('max_length', None) + max_length = kwargs.pop("max_length", None) super(ModelField, self).__init__(**kwargs) if max_length is not None: - message = lazy( - self.error_messages['max_length'].format, - six.text_type)(max_length=self.max_length) - self.validators.append( - MaxLengthValidator(self.max_length, message=message)) + message = lazy(self.error_messages["max_length"].format, six.text_type)( + max_length=self.max_length + ) + self.validators.append(MaxLengthValidator(self.max_length, message=message)) def to_internal_value(self, data): rel = self.model_field.remote_field diff --git a/rest_framework/filters.py b/rest_framework/filters.py index bb1b86586..34d7d6225 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -18,9 +18,7 @@ from django.utils.encoding import force_text from django.utils.translation import ugettext_lazy as _ from rest_framework import RemovedInDRF310Warning -from rest_framework.compat import ( - coreapi, coreschema, distinct, is_guardian_installed -) +from rest_framework.compat import coreapi, coreschema, distinct, is_guardian_installed from rest_framework.settings import api_settings @@ -36,23 +34,22 @@ class BaseFilterBackend(object): raise NotImplementedError(".filter_queryset() must be overridden.") def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + assert ( + coreapi is not None + ), "coreapi must be installed to use `get_schema_fields()`" + assert ( + coreschema is not None + ), "coreschema must be installed to use `get_schema_fields()`" return [] class SearchFilter(BaseFilterBackend): # The URL query parameter used for the search. search_param = api_settings.SEARCH_PARAM - template = 'rest_framework/filters/search.html' - lookup_prefixes = { - '^': 'istartswith', - '=': 'iexact', - '@': 'search', - '$': 'iregex', - } - search_title = _('Search') - search_description = _('A search term.') + template = "rest_framework/filters/search.html" + lookup_prefixes = {"^": "istartswith", "=": "iexact", "@": "search", "$": "iregex"} + search_title = _("Search") + search_description = _("A search term.") def get_search_fields(self, view, request): """ @@ -60,22 +57,22 @@ class SearchFilter(BaseFilterBackend): passed to this method. Sub-classes can override this method to dynamically change the search fields based on request content. """ - return getattr(view, 'search_fields', None) + return getattr(view, "search_fields", None) def get_search_terms(self, request): """ Search terms are set by a ?search=... query parameter, and may be comma and/or whitespace delimited. """ - params = request.query_params.get(self.search_param, '') - return params.replace(',', ' ').split() + params = request.query_params.get(self.search_param, "") + return params.replace(",", " ").split() def construct_search(self, field_name): lookup = self.lookup_prefixes.get(field_name[0]) if lookup: field_name = field_name[1:] else: - lookup = 'icontains' + lookup = "icontains" return LOOKUP_SEP.join([field_name, lookup]) def must_call_distinct(self, queryset, search_fields): @@ -87,12 +84,15 @@ class SearchFilter(BaseFilterBackend): if search_field[0] in self.lookup_prefixes: search_field = search_field[1:] # Annotated fields do not need to be distinct - if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations: + if ( + isinstance(queryset, models.QuerySet) + and search_field in queryset.query.annotations + ): return False parts = search_field.split(LOOKUP_SEP) for part in parts: field = opts.get_field(part) - if hasattr(field, 'get_path_info'): + if hasattr(field, "get_path_info"): # This field is a relation, update opts to follow the relation path_info = field.get_path_info() opts = path_info[-1].to_opts @@ -117,8 +117,7 @@ class SearchFilter(BaseFilterBackend): conditions = [] for search_term in search_terms: queries = [ - models.Q(**{orm_lookup: search_term}) - for orm_lookup in orm_lookups + models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups ] conditions.append(reduce(operator.or_, queries)) queryset = queryset.filter(reduce(operator.and_, conditions)) @@ -132,30 +131,31 @@ class SearchFilter(BaseFilterBackend): return queryset def to_html(self, request, queryset, view): - if not getattr(view, 'search_fields', None): - return '' + if not getattr(view, "search_fields", None): + return "" term = self.get_search_terms(request) - term = term[0] if term else '' - context = { - 'param': self.search_param, - 'term': term - } + term = term[0] if term else "" + context = {"param": self.search_param, "term": term} template = loader.get_template(self.template) return template.render(context) def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + assert ( + coreapi is not None + ), "coreapi must be installed to use `get_schema_fields()`" + assert ( + coreschema is not None + ), "coreschema must be installed to use `get_schema_fields()`" return [ coreapi.Field( name=self.search_param, required=False, - location='query', + location="query", schema=coreschema.String( title=force_text(self.search_title), - description=force_text(self.search_description) - ) + description=force_text(self.search_description), + ), ) ] @@ -164,9 +164,9 @@ class OrderingFilter(BaseFilterBackend): # The URL query parameter used for the ordering. ordering_param = api_settings.ORDERING_PARAM ordering_fields = None - ordering_title = _('Ordering') - ordering_description = _('Which field to use when ordering the results.') - template = 'rest_framework/filters/ordering.html' + ordering_title = _("Ordering") + ordering_description = _("Which field to use when ordering the results.") + template = "rest_framework/filters/ordering.html" def get_ordering(self, request, queryset, view): """ @@ -178,7 +178,7 @@ class OrderingFilter(BaseFilterBackend): """ params = request.query_params.get(self.ordering_param) if params: - fields = [param.strip() for param in params.split(',')] + fields = [param.strip() for param in params.split(",")] ordering = self.remove_invalid_fields(queryset, fields, view, request) if ordering: return ordering @@ -187,7 +187,7 @@ class OrderingFilter(BaseFilterBackend): return self.get_default_ordering(view) def get_default_ordering(self, view): - ordering = getattr(view, 'ordering', None) + ordering = getattr(view, "ordering", None) if isinstance(ordering, six.string_types): return (ordering,) return ordering @@ -195,7 +195,7 @@ class OrderingFilter(BaseFilterBackend): def get_default_valid_fields(self, queryset, view, context={}): # If `ordering_fields` is not specified, then we determine a default # based on the serializer class, if one exists on the view. - if hasattr(view, 'get_serializer_class'): + if hasattr(view, "get_serializer_class"): try: serializer_class = view.get_serializer_class() except AssertionError: @@ -203,7 +203,7 @@ class OrderingFilter(BaseFilterBackend): # no serializer_class was found serializer_class = None else: - serializer_class = getattr(view, 'serializer_class', None) + serializer_class = getattr(view, "serializer_class", None) if serializer_class is None: msg = ( @@ -214,26 +214,26 @@ class OrderingFilter(BaseFilterBackend): raise ImproperlyConfigured(msg % self.__class__.__name__) return [ - (field.source.replace('.', '__') or field_name, field.label) + (field.source.replace(".", "__") or field_name, field.label) for field_name, field in serializer_class(context=context).fields.items() - if not getattr(field, 'write_only', False) and not field.source == '*' + if not getattr(field, "write_only", False) and not field.source == "*" ] def get_valid_fields(self, queryset, view, context={}): - valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) + valid_fields = getattr(view, "ordering_fields", self.ordering_fields) if valid_fields is None: # Default to allowing filtering on serializer fields return self.get_default_valid_fields(queryset, view, context) - elif valid_fields == '__all__': + elif valid_fields == "__all__": # View explicitly allows filtering on any model field valid_fields = [ - (field.name, field.verbose_name) for field in queryset.model._meta.fields + (field.name, field.verbose_name) + for field in queryset.model._meta.fields ] valid_fields += [ - (key, key.title().split('__')) - for key in queryset.query.annotations + (key, key.title().split("__")) for key in queryset.query.annotations ] else: valid_fields = [ @@ -244,8 +244,15 @@ class OrderingFilter(BaseFilterBackend): return valid_fields def remove_invalid_fields(self, queryset, fields, view, request): - valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})] - return [term for term in fields if term.lstrip('-') in valid_fields and ORDER_PATTERN.match(term)] + valid_fields = [ + item[0] + for item in self.get_valid_fields(queryset, view, {"request": request}) + ] + return [ + term + for term in fields + if term.lstrip("-") in valid_fields and ORDER_PATTERN.match(term) + ] def filter_queryset(self, request, queryset, view): ordering = self.get_ordering(request, queryset, view) @@ -259,15 +266,11 @@ class OrderingFilter(BaseFilterBackend): current = self.get_ordering(request, queryset, view) current = None if not current else current[0] options = [] - context = { - 'request': request, - 'current': current, - 'param': self.ordering_param, - } + context = {"request": request, "current": current, "param": self.ordering_param} for key, label in self.get_valid_fields(queryset, view, context): - options.append((key, '%s - %s' % (label, _('ascending')))) - options.append(('-' + key, '%s - %s' % (label, _('descending')))) - context['options'] = options + options.append((key, "%s - %s" % (label, _("ascending")))) + options.append(("-" + key, "%s - %s" % (label, _("descending")))) + context["options"] = options return context def to_html(self, request, queryset, view): @@ -276,17 +279,21 @@ class OrderingFilter(BaseFilterBackend): return template.render(context) def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + assert ( + coreapi is not None + ), "coreapi must be installed to use `get_schema_fields()`" + assert ( + coreschema is not None + ), "coreschema must be installed to use `get_schema_fields()`" return [ coreapi.Field( name=self.ordering_param, required=False, - location='query', + location="query", schema=coreschema.String( title=force_text(self.ordering_title), - description=force_text(self.ordering_description) - ) + description=force_text(self.ordering_description), + ), ) ] @@ -296,15 +303,19 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend): A filter backend that limits results to those where the requesting user has read object level permissions. """ + def __init__(self): warnings.warn( "`DjangoObjectPermissionsFilter` has been deprecated and moved to " "the 3rd-party django-rest-framework-guardian package.", - RemovedInDRF310Warning, stacklevel=2 + RemovedInDRF310Warning, + stacklevel=2, ) - assert is_guardian_installed(), 'Using DjangoObjectPermissionsFilter, but django-guardian is not installed' + assert ( + is_guardian_installed() + ), "Using DjangoObjectPermissionsFilter, but django-guardian is not installed" - perm_format = '%(app_label)s.view_%(model_name)s' + perm_format = "%(app_label)s.view_%(model_name)s" def filter_queryset(self, request, queryset, view): # We want to defer this import until run-time, rather than import-time. @@ -317,13 +328,13 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend): user = request.user model_cls = queryset.model kwargs = { - 'app_label': model_cls._meta.app_label, - 'model_name': model_cls._meta.model_name + "app_label": model_cls._meta.app_label, + "model_name": model_cls._meta.model_name, } permission = self.perm_format % kwargs if tuple(guardian_version) >= (1, 3): # Maintain behavior compatibility with versions prior to 1.3 - extra = {'accept_global_perms': False} + extra = {"accept_global_perms": False} else: extra = {} return get_objects_for_user(user, permission, queryset, **extra) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8d0bf284a..e5e132422 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -27,6 +27,7 @@ class GenericAPIView(views.APIView): """ Base class for all other generic views. """ + # You'll need to either set these attributes, # or override `get_queryset()`/`get_serializer_class()`. # If you are overriding a view method, it is important that you call @@ -38,7 +39,7 @@ class GenericAPIView(views.APIView): # If you want to use object lookups other than pk, set 'lookup_field'. # For more complex lookup requirements override `get_object()`. - lookup_field = 'pk' + lookup_field = "pk" lookup_url_kwarg = None # The filter backend classes to use for queryset filtering @@ -64,8 +65,7 @@ class GenericAPIView(views.APIView): """ assert self.queryset is not None, ( "'%s' should either include a `queryset` attribute, " - "or override the `get_queryset()` method." - % self.__class__.__name__ + "or override the `get_queryset()` method." % self.__class__.__name__ ) queryset = self.queryset @@ -88,10 +88,10 @@ class GenericAPIView(views.APIView): lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field assert lookup_url_kwarg in self.kwargs, ( - 'Expected view %s to be called with a URL keyword argument ' + "Expected view %s to be called with a URL keyword argument " 'named "%s". Fix your URL conf, or set the `.lookup_field` ' - 'attribute on the view correctly.' % - (self.__class__.__name__, lookup_url_kwarg) + "attribute on the view correctly." + % (self.__class__.__name__, lookup_url_kwarg) ) filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]} @@ -108,7 +108,7 @@ class GenericAPIView(views.APIView): deserializing input, and for serializing output. """ serializer_class = self.get_serializer_class() - kwargs['context'] = self.get_serializer_context() + kwargs["context"] = self.get_serializer_context() return serializer_class(*args, **kwargs) def get_serializer_class(self): @@ -123,8 +123,7 @@ class GenericAPIView(views.APIView): """ assert self.serializer_class is not None, ( "'%s' should either include a `serializer_class` attribute, " - "or override the `get_serializer_class()` method." - % self.__class__.__name__ + "or override the `get_serializer_class()` method." % self.__class__.__name__ ) return self.serializer_class @@ -133,11 +132,7 @@ class GenericAPIView(views.APIView): """ Extra context provided to the serializer class. """ - return { - 'request': self.request, - 'format': self.format_kwarg, - 'view': self - } + return {"request": self.request, "format": self.format_kwarg, "view": self} def filter_queryset(self, queryset): """ @@ -157,7 +152,7 @@ class GenericAPIView(views.APIView): """ The paginator instance associated with the view, or `None`. """ - if not hasattr(self, '_paginator'): + if not hasattr(self, "_paginator"): if self.pagination_class is None: self._paginator = None else: @@ -183,47 +178,48 @@ class GenericAPIView(views.APIView): # Concrete view classes that provide method handlers # by composing the mixin classes with the base view. -class CreateAPIView(mixins.CreateModelMixin, - GenericAPIView): + +class CreateAPIView(mixins.CreateModelMixin, GenericAPIView): """ Concrete view for creating a model instance. """ + def post(self, request, *args, **kwargs): return self.create(request, *args, **kwargs) -class ListAPIView(mixins.ListModelMixin, - GenericAPIView): +class ListAPIView(mixins.ListModelMixin, GenericAPIView): """ Concrete view for listing a queryset. """ + def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) -class RetrieveAPIView(mixins.RetrieveModelMixin, - GenericAPIView): +class RetrieveAPIView(mixins.RetrieveModelMixin, GenericAPIView): """ Concrete view for retrieving a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) -class DestroyAPIView(mixins.DestroyModelMixin, - GenericAPIView): +class DestroyAPIView(mixins.DestroyModelMixin, GenericAPIView): """ Concrete view for deleting a model instance. """ + def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) -class UpdateAPIView(mixins.UpdateModelMixin, - GenericAPIView): +class UpdateAPIView(mixins.UpdateModelMixin, GenericAPIView): """ Concrete view for updating a model instance. """ + def put(self, request, *args, **kwargs): return self.update(request, *args, **kwargs) @@ -231,12 +227,11 @@ class UpdateAPIView(mixins.UpdateModelMixin, return self.partial_update(request, *args, **kwargs) -class ListCreateAPIView(mixins.ListModelMixin, - mixins.CreateModelMixin, - GenericAPIView): +class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, GenericAPIView): """ Concrete view for listing a queryset or creating a model instance. """ + def get(self, request, *args, **kwargs): return self.list(request, *args, **kwargs) @@ -244,12 +239,13 @@ class ListCreateAPIView(mixins.ListModelMixin, return self.create(request, *args, **kwargs) -class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - GenericAPIView): +class RetrieveUpdateAPIView( + mixins.RetrieveModelMixin, mixins.UpdateModelMixin, GenericAPIView +): """ Concrete view for retrieving, updating a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) @@ -260,12 +256,13 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, return self.partial_update(request, *args, **kwargs) -class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, - mixins.DestroyModelMixin, - GenericAPIView): +class RetrieveDestroyAPIView( + mixins.RetrieveModelMixin, mixins.DestroyModelMixin, GenericAPIView +): """ Concrete view for retrieving or deleting a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) @@ -273,13 +270,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, return self.destroy(request, *args, **kwargs) -class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - mixins.DestroyModelMixin, - GenericAPIView): +class RetrieveUpdateDestroyAPIView( + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + GenericAPIView, +): """ Concrete view for retrieving, updating or deleting a model instance. """ + def get(self, request, *args, **kwargs): return self.retrieve(request, *args, **kwargs) diff --git a/rest_framework/management/commands/generateschema.py b/rest_framework/management/commands/generateschema.py index 591073ba0..aad9743c3 100644 --- a/rest_framework/management/commands/generateschema.py +++ b/rest_framework/management/commands/generateschema.py @@ -2,7 +2,9 @@ from django.core.management.base import BaseCommand from rest_framework.compat import coreapi from rest_framework.renderers import ( - CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer + CoreJSONRenderer, + JSONOpenAPIRenderer, + OpenAPIRenderer, ) from rest_framework.schemas.generators import SchemaGenerator @@ -11,31 +13,37 @@ class Command(BaseCommand): help = "Generates configured API schema for project." def add_arguments(self, parser): - parser.add_argument('--title', dest="title", default=None, type=str) - parser.add_argument('--url', dest="url", default=None, type=str) - parser.add_argument('--description', dest="description", default=None, type=str) - parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str) + parser.add_argument("--title", dest="title", default=None, type=str) + parser.add_argument("--url", dest="url", default=None, type=str) + parser.add_argument("--description", dest="description", default=None, type=str) + parser.add_argument( + "--format", + dest="format", + choices=["openapi", "openapi-json", "corejson"], + default="openapi", + type=str, + ) def handle(self, *args, **options): - assert coreapi is not None, 'coreapi must be installed.' + assert coreapi is not None, "coreapi must be installed." generator = SchemaGenerator( - url=options['url'], - title=options['title'], - description=options['description'] + url=options["url"], + title=options["title"], + description=options["description"], ) schema = generator.get_schema(request=None, public=True) - renderer = self.get_renderer(options['format']) + renderer = self.get_renderer(options["format"]) output = renderer.render(schema, renderer_context={}) - self.stdout.write(output.decode('utf-8')) + self.stdout.write(output.decode("utf-8")) def get_renderer(self, format): renderer_cls = { - 'corejson': CoreJSONRenderer, - 'openapi': OpenAPIRenderer, - 'openapi-json': JSONOpenAPIRenderer, + "corejson": CoreJSONRenderer, + "openapi": OpenAPIRenderer, + "openapi-json": JSONOpenAPIRenderer, }[format] return renderer_cls() diff --git a/rest_framework/metadata.py b/rest_framework/metadata.py index 9f9324469..76b0370f5 100644 --- a/rest_framework/metadata.py +++ b/rest_framework/metadata.py @@ -35,41 +35,46 @@ class SimpleMetadata(BaseMetadata): There are not any formalized standards for `OPTIONS` responses for us to base this on. """ - label_lookup = ClassLookupDict({ - serializers.Field: 'field', - serializers.BooleanField: 'boolean', - serializers.NullBooleanField: 'boolean', - serializers.CharField: 'string', - serializers.UUIDField: 'string', - serializers.URLField: 'url', - serializers.EmailField: 'email', - serializers.RegexField: 'regex', - serializers.SlugField: 'slug', - serializers.IntegerField: 'integer', - serializers.FloatField: 'float', - serializers.DecimalField: 'decimal', - serializers.DateField: 'date', - serializers.DateTimeField: 'datetime', - serializers.TimeField: 'time', - serializers.ChoiceField: 'choice', - serializers.MultipleChoiceField: 'multiple choice', - serializers.FileField: 'file upload', - serializers.ImageField: 'image upload', - serializers.ListField: 'list', - serializers.DictField: 'nested object', - serializers.Serializer: 'nested object', - }) + + label_lookup = ClassLookupDict( + { + serializers.Field: "field", + serializers.BooleanField: "boolean", + serializers.NullBooleanField: "boolean", + serializers.CharField: "string", + serializers.UUIDField: "string", + serializers.URLField: "url", + serializers.EmailField: "email", + serializers.RegexField: "regex", + serializers.SlugField: "slug", + serializers.IntegerField: "integer", + serializers.FloatField: "float", + serializers.DecimalField: "decimal", + serializers.DateField: "date", + serializers.DateTimeField: "datetime", + serializers.TimeField: "time", + serializers.ChoiceField: "choice", + serializers.MultipleChoiceField: "multiple choice", + serializers.FileField: "file upload", + serializers.ImageField: "image upload", + serializers.ListField: "list", + serializers.DictField: "nested object", + serializers.Serializer: "nested object", + } + ) def determine_metadata(self, request, view): metadata = OrderedDict() - metadata['name'] = view.get_view_name() - metadata['description'] = view.get_view_description() - metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes] - metadata['parses'] = [parser.media_type for parser in view.parser_classes] - if hasattr(view, 'get_serializer'): + metadata["name"] = view.get_view_name() + metadata["description"] = view.get_view_description() + metadata["renders"] = [ + renderer.media_type for renderer in view.renderer_classes + ] + metadata["parses"] = [parser.media_type for parser in view.parser_classes] + if hasattr(view, "get_serializer"): actions = self.determine_actions(request, view) if actions: - metadata['actions'] = actions + metadata["actions"] = actions return metadata def determine_actions(self, request, view): @@ -78,14 +83,14 @@ class SimpleMetadata(BaseMetadata): the fields that are accepted for 'PUT' and 'POST' methods. """ actions = {} - for method in {'PUT', 'POST'} & set(view.allowed_methods): + for method in {"PUT", "POST"} & set(view.allowed_methods): view.request = clone_request(request, method) try: # Test global permissions - if hasattr(view, 'check_permissions'): + if hasattr(view, "check_permissions"): view.check_permissions(view.request) # Test object permissions - if method == 'PUT' and hasattr(view, 'get_object'): + if method == "PUT" and hasattr(view, "get_object"): view.get_object() except (exceptions.APIException, PermissionDenied, Http404): pass @@ -104,15 +109,17 @@ class SimpleMetadata(BaseMetadata): Given an instance of a serializer, return a dictionary of metadata about its fields. """ - if hasattr(serializer, 'child'): + if hasattr(serializer, "child"): # If this is a `ListSerializer` then we want to examine the # underlying child serializer instance instead. serializer = serializer.child - return OrderedDict([ - (field_name, self.get_field_info(field)) - for field_name, field in serializer.fields.items() - if not isinstance(field, serializers.HiddenField) - ]) + return OrderedDict( + [ + (field_name, self.get_field_info(field)) + for field_name, field in serializer.fields.items() + if not isinstance(field, serializers.HiddenField) + ] + ) def get_field_info(self, field): """ @@ -120,32 +127,40 @@ class SimpleMetadata(BaseMetadata): of metadata about it. """ field_info = OrderedDict() - field_info['type'] = self.label_lookup[field] - field_info['required'] = getattr(field, 'required', False) + field_info["type"] = self.label_lookup[field] + field_info["required"] = getattr(field, "required", False) attrs = [ - 'read_only', 'label', 'help_text', - 'min_length', 'max_length', - 'min_value', 'max_value' + "read_only", + "label", + "help_text", + "min_length", + "max_length", + "min_value", + "max_value", ] for attr in attrs: value = getattr(field, attr, None) - if value is not None and value != '': + if value is not None and value != "": field_info[attr] = force_text(value, strings_only=True) - if getattr(field, 'child', None): - field_info['child'] = self.get_field_info(field.child) - elif getattr(field, 'fields', None): - field_info['children'] = self.get_serializer_info(field) + if getattr(field, "child", None): + field_info["child"] = self.get_field_info(field.child) + elif getattr(field, "fields", None): + field_info["children"] = self.get_serializer_info(field) - if (not field_info.get('read_only') and - not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and - hasattr(field, 'choices')): - field_info['choices'] = [ + if ( + not field_info.get("read_only") + and not isinstance( + field, (serializers.RelatedField, serializers.ManyRelatedField) + ) + and hasattr(field, "choices") + ): + field_info["choices"] = [ { - 'value': choice_value, - 'display_name': force_text(choice_name, strings_only=True) + "value": choice_value, + "display_name": force_text(choice_name, strings_only=True), } for choice_value, choice_name in field.choices.items() ] diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index de10d6930..4855140ad 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -15,19 +15,22 @@ class CreateModelMixin(object): """ Create a model instance. """ + def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) self.perform_create(serializer) headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + return Response( + serializer.data, status=status.HTTP_201_CREATED, headers=headers + ) def perform_create(self, serializer): serializer.save() def get_success_headers(self, data): try: - return {'Location': str(data[api_settings.URL_FIELD_NAME])} + return {"Location": str(data[api_settings.URL_FIELD_NAME])} except (TypeError, KeyError): return {} @@ -36,6 +39,7 @@ class ListModelMixin(object): """ List a queryset. """ + def list(self, request, *args, **kwargs): queryset = self.filter_queryset(self.get_queryset()) @@ -52,6 +56,7 @@ class RetrieveModelMixin(object): """ Retrieve a model instance. """ + def retrieve(self, request, *args, **kwargs): instance = self.get_object() serializer = self.get_serializer(instance) @@ -62,14 +67,15 @@ class UpdateModelMixin(object): """ Update a model instance. """ + def update(self, request, *args, **kwargs): - partial = kwargs.pop('partial', False) + partial = kwargs.pop("partial", False) instance = self.get_object() serializer = self.get_serializer(instance, data=request.data, partial=partial) serializer.is_valid(raise_exception=True) self.perform_update(serializer) - if getattr(instance, '_prefetched_objects_cache', None): + if getattr(instance, "_prefetched_objects_cache", None): # If 'prefetch_related' has been applied to a queryset, we need to # forcibly invalidate the prefetch cache on the instance. instance._prefetched_objects_cache = {} @@ -80,7 +86,7 @@ class UpdateModelMixin(object): serializer.save() def partial_update(self, request, *args, **kwargs): - kwargs['partial'] = True + kwargs["partial"] = True return self.update(request, *args, **kwargs) @@ -88,6 +94,7 @@ class DestroyModelMixin(object): """ Destroy a model instance. """ + def destroy(self, request, *args, **kwargs): instance = self.get_object() self.perform_destroy(instance) diff --git a/rest_framework/negotiation.py b/rest_framework/negotiation.py index ca1b59f12..e9d89dd14 100644 --- a/rest_framework/negotiation.py +++ b/rest_framework/negotiation.py @@ -9,16 +9,18 @@ from django.http import Http404 from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework.settings import api_settings from rest_framework.utils.mediatypes import ( - _MediaType, media_type_matches, order_by_precedence + _MediaType, + media_type_matches, + order_by_precedence, ) class BaseContentNegotiation(object): def select_parser(self, request, parsers): - raise NotImplementedError('.select_parser() must be implemented') + raise NotImplementedError(".select_parser() must be implemented") def select_renderer(self, request, renderers, format_suffix=None): - raise NotImplementedError('.select_renderer() must be implemented') + raise NotImplementedError(".select_renderer() must be implemented") class DefaultContentNegotiation(BaseContentNegotiation): @@ -59,16 +61,20 @@ class DefaultContentNegotiation(BaseContentNegotiation): # Return the most specific media type as accepted. media_type_wrapper = _MediaType(media_type) if ( - _MediaType(renderer.media_type).precedence > - media_type_wrapper.precedence + _MediaType(renderer.media_type).precedence + > media_type_wrapper.precedence ): # Eg client requests '*/*' # Accepted media type is 'application/json' - full_media_type = ';'.join( - (renderer.media_type,) + - tuple('{0}={1}'.format( - key, value.decode(HTTP_HEADER_ENCODING)) - for key, value in media_type_wrapper.params.items())) + full_media_type = ";".join( + (renderer.media_type,) + + tuple( + "{0}={1}".format( + key, value.decode(HTTP_HEADER_ENCODING) + ) + for key, value in media_type_wrapper.params.items() + ) + ) return renderer, full_media_type else: # Eg client requests 'application/json; indent=8' @@ -82,8 +88,7 @@ class DefaultContentNegotiation(BaseContentNegotiation): If there is a '.json' style format suffix, filter the renderers so that we only negotiation against those that accept that format. """ - renderers = [renderer for renderer in renderers - if renderer.format == format] + renderers = [renderer for renderer in renderers if renderer.format == format] if not renderers: raise Http404 return renderers @@ -93,5 +98,5 @@ class DefaultContentNegotiation(BaseContentNegotiation): Given the incoming request, return a tokenized list of media type strings. """ - header = request.META.get('HTTP_ACCEPT', '*/*') - return [token.strip() for token in header.split(',')] + header = request.META.get("HTTP_ACCEPT", "*/*") + return [token.strip() for token in header.split(",")] diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index b11d7cdf3..5b61a4b79 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -8,8 +8,7 @@ from __future__ import unicode_literals from base64 import b64decode, b64encode from collections import OrderedDict, namedtuple -from django.core.paginator import InvalidPage -from django.core.paginator import Paginator as DjangoPaginator +from django.core.paginator import InvalidPage, Paginator as DjangoPaginator from django.template import loader from django.utils import six from django.utils.encoding import force_text @@ -83,10 +82,7 @@ def _get_displayed_page_numbers(current, final): included.add(final - 2) # Now sort the page numbers and drop anything outside the limits. - included = [ - idx for idx in sorted(list(included)) - if 0 < idx <= final - ] + included = [idx for idx in sorted(list(included)) if 0 < idx <= final] # Finally insert any `...` breaks if current > 4: @@ -110,7 +106,7 @@ def _get_page_links(page_numbers, current, url_func): url=url_func(page_number), number=page_number, is_active=(page_number == current), - is_break=False + is_break=False, ) page_links.append(page_link) return page_links @@ -121,14 +117,15 @@ def _reverse_ordering(ordering_tuple): Given an order_by tuple such as `('-created', 'uuid')` reverse the ordering and return a new tuple, eg. `('created', '-uuid')`. """ + def invert(x): - return x[1:] if x.startswith('-') else '-' + x + return x[1:] if x.startswith("-") else "-" + x return tuple([invert(item) for item in ordering_tuple]) -Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position']) -PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break']) +Cursor = namedtuple("Cursor", ["offset", "reverse", "position"]) +PageLink = namedtuple("PageLink", ["url", "number", "is_active", "is_break"]) PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True) @@ -137,19 +134,23 @@ class BasePagination(object): display_page_controls = False def paginate_queryset(self, queryset, request, view=None): # pragma: no cover - raise NotImplementedError('paginate_queryset() must be implemented.') + raise NotImplementedError("paginate_queryset() must be implemented.") def get_paginated_response(self, data): # pragma: no cover - raise NotImplementedError('get_paginated_response() must be implemented.') + raise NotImplementedError("get_paginated_response() must be implemented.") def to_html(self): # pragma: no cover - raise NotImplementedError('to_html() must be implemented to display page controls.') + raise NotImplementedError( + "to_html() must be implemented to display page controls." + ) def get_results(self, data): - return data['results'] + return data["results"] def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' + assert ( + coreapi is not None + ), "coreapi must be installed to use `get_schema_fields()`" return [] @@ -161,6 +162,7 @@ class PageNumberPagination(BasePagination): http://api.example.org/accounts/?page=4 http://api.example.org/accounts/?page=4&page_size=100 """ + # The default page size. # Defaults to `None`, meaning pagination is disabled. page_size = api_settings.PAGE_SIZE @@ -168,23 +170,23 @@ class PageNumberPagination(BasePagination): django_paginator_class = DjangoPaginator # Client can control the page using this query parameter. - page_query_param = 'page' - page_query_description = _('A page number within the paginated result set.') + page_query_param = "page" + page_query_description = _("A page number within the paginated result set.") # Client can control the page size using this query parameter. # Default is 'None'. Set to eg 'page_size' to enable usage. page_size_query_param = None - page_size_query_description = _('Number of results to return per page.') + page_size_query_description = _("Number of results to return per page.") # Set to an integer to limit the maximum page size the client may request. # Only relevant if 'page_size_query_param' has also been set. max_page_size = None - last_page_strings = ('last',) + last_page_strings = ("last",) - template = 'rest_framework/pagination/numbers.html' + template = "rest_framework/pagination/numbers.html" - invalid_page_message = _('Invalid page.') + invalid_page_message = _("Invalid page.") def paginate_queryset(self, queryset, request, view=None): """ @@ -216,12 +218,16 @@ class PageNumberPagination(BasePagination): return list(self.page) def get_paginated_response(self, data): - return Response(OrderedDict([ - ('count', self.page.paginator.count), - ('next', self.get_next_link()), - ('previous', self.get_previous_link()), - ('results', data) - ])) + return Response( + OrderedDict( + [ + ("count", self.page.paginator.count), + ("next", self.get_next_link()), + ("previous", self.get_previous_link()), + ("results", data), + ] + ) + ) def get_page_size(self, request): if self.page_size_query_param: @@ -229,7 +235,7 @@ class PageNumberPagination(BasePagination): return _positive_int( request.query_params[self.page_size_query_param], strict=True, - cutoff=self.max_page_size + cutoff=self.max_page_size, ) except (KeyError, ValueError): pass @@ -267,9 +273,9 @@ class PageNumberPagination(BasePagination): page_links = _get_page_links(page_numbers, current, page_number_to_url) return { - 'previous_url': self.get_previous_link(), - 'next_url': self.get_next_link(), - 'page_links': page_links + "previous_url": self.get_previous_link(), + "next_url": self.get_next_link(), + "page_links": page_links, } def to_html(self): @@ -278,17 +284,20 @@ class PageNumberPagination(BasePagination): return template.render(context) def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + assert ( + coreapi is not None + ), "coreapi must be installed to use `get_schema_fields()`" + assert ( + coreschema is not None + ), "coreschema must be installed to use `get_schema_fields()`" fields = [ coreapi.Field( name=self.page_query_param, required=False, - location='query', + location="query", schema=coreschema.Integer( - title='Page', - description=force_text(self.page_query_description) - ) + title="Page", description=force_text(self.page_query_description) + ), ) ] if self.page_size_query_param is not None: @@ -296,11 +305,11 @@ class PageNumberPagination(BasePagination): coreapi.Field( name=self.page_size_query_param, required=False, - location='query', + location="query", schema=coreschema.Integer( - title='Page size', - description=force_text(self.page_size_query_description) - ) + title="Page size", + description=force_text(self.page_size_query_description), + ), ) ) return fields @@ -313,13 +322,14 @@ class LimitOffsetPagination(BasePagination): http://api.example.org/accounts/?limit=100 http://api.example.org/accounts/?offset=400&limit=100 """ + default_limit = api_settings.PAGE_SIZE - limit_query_param = 'limit' - limit_query_description = _('Number of results to return per page.') - offset_query_param = 'offset' - offset_query_description = _('The initial index from which to return the results.') + limit_query_param = "limit" + limit_query_description = _("Number of results to return per page.") + offset_query_param = "offset" + offset_query_description = _("The initial index from which to return the results.") max_limit = None - template = 'rest_framework/pagination/numbers.html' + template = "rest_framework/pagination/numbers.html" def paginate_queryset(self, queryset, request, view=None): self.count = self.get_count(queryset) @@ -334,15 +344,19 @@ class LimitOffsetPagination(BasePagination): if self.count == 0 or self.offset > self.count: return [] - return list(queryset[self.offset:self.offset + self.limit]) + return list(queryset[self.offset : self.offset + self.limit]) def get_paginated_response(self, data): - return Response(OrderedDict([ - ('count', self.count), - ('next', self.get_next_link()), - ('previous', self.get_previous_link()), - ('results', data) - ])) + return Response( + OrderedDict( + [ + ("count", self.count), + ("next", self.get_next_link()), + ("previous", self.get_previous_link()), + ("results", data), + ] + ) + ) def get_limit(self, request): if self.limit_query_param: @@ -350,7 +364,7 @@ class LimitOffsetPagination(BasePagination): return _positive_int( request.query_params[self.limit_query_param], strict=True, - cutoff=self.max_limit + cutoff=self.max_limit, ) except (KeyError, ValueError): pass @@ -359,9 +373,7 @@ class LimitOffsetPagination(BasePagination): def get_offset(self, request): try: - return _positive_int( - request.query_params[self.offset_query_param], - ) + return _positive_int(request.query_params[self.offset_query_param]) except (KeyError, ValueError): return 0 @@ -399,10 +411,9 @@ class LimitOffsetPagination(BasePagination): # plus the number of pages up to the current offset. # When offset is not strictly divisible by the limit then we may # end up introducing an extra page as an artifact. - final = ( - _divide_with_ceil(self.count - self.offset, self.limit) + - _divide_with_ceil(self.offset, self.limit) - ) + final = _divide_with_ceil( + self.count - self.offset, self.limit + ) + _divide_with_ceil(self.offset, self.limit) if final < 1: final = 1 @@ -424,9 +435,9 @@ class LimitOffsetPagination(BasePagination): page_links = _get_page_links(page_numbers, current, page_number_to_url) return { - 'previous_url': self.get_previous_link(), - 'next_url': self.get_next_link(), - 'page_links': page_links + "previous_url": self.get_previous_link(), + "next_url": self.get_next_link(), + "page_links": page_links, } def to_html(self): @@ -435,27 +446,30 @@ class LimitOffsetPagination(BasePagination): return template.render(context) def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + assert ( + coreapi is not None + ), "coreapi must be installed to use `get_schema_fields()`" + assert ( + coreschema is not None + ), "coreschema must be installed to use `get_schema_fields()`" return [ coreapi.Field( name=self.limit_query_param, required=False, - location='query', + location="query", schema=coreschema.Integer( - title='Limit', - description=force_text(self.limit_query_description) - ) + title="Limit", description=force_text(self.limit_query_description) + ), ), coreapi.Field( name=self.offset_query_param, required=False, - location='query', + location="query", schema=coreschema.Integer( - title='Offset', - description=force_text(self.offset_query_description) - ) - ) + title="Offset", + description=force_text(self.offset_query_description), + ), + ), ] def get_count(self, queryset): @@ -474,17 +488,18 @@ class CursorPagination(BasePagination): For an overview of the position/offset style we use, see this post: https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api """ - cursor_query_param = 'cursor' - cursor_query_description = _('The pagination cursor value.') + + cursor_query_param = "cursor" + cursor_query_description = _("The pagination cursor value.") page_size = api_settings.PAGE_SIZE - invalid_cursor_message = _('Invalid cursor') - ordering = '-created' - template = 'rest_framework/pagination/previous_and_next.html' + invalid_cursor_message = _("Invalid cursor") + ordering = "-created" + template = "rest_framework/pagination/previous_and_next.html" # Client can control the page size using this query parameter. # Default is 'None'. Set to eg 'page_size' to enable usage. page_size_query_param = None - page_size_query_description = _('Number of results to return per page.') + page_size_query_description = _("Number of results to return per page.") # Set to an integer to limit the maximum page size the client may request. # Only relevant if 'page_size_query_param' has also been set. @@ -519,27 +534,29 @@ class CursorPagination(BasePagination): # If we have a cursor with a fixed position then filter by that. if current_position is not None: order = self.ordering[0] - is_reversed = order.startswith('-') - order_attr = order.lstrip('-') + is_reversed = order.startswith("-") + order_attr = order.lstrip("-") # Test for: (cursor reversed) XOR (queryset reversed) if self.cursor.reverse != is_reversed: - kwargs = {order_attr + '__lt': current_position} + kwargs = {order_attr + "__lt": current_position} else: - kwargs = {order_attr + '__gt': current_position} + kwargs = {order_attr + "__gt": current_position} queryset = queryset.filter(**kwargs) # If we have an offset cursor then offset the entire page by that amount. # We also always fetch an extra item in order to determine if there is a # page following on from this one. - results = list(queryset[offset:offset + self.page_size + 1]) - self.page = list(results[:self.page_size]) + results = list(queryset[offset : offset + self.page_size + 1]) + self.page = list(results[: self.page_size]) # Determine the position of the final item following the page. if len(results) > len(self.page): has_following_position = True - following_position = self._get_position_from_instance(results[-1], self.ordering) + following_position = self._get_position_from_instance( + results[-1], self.ordering + ) else: has_following_position = False following_position = None @@ -578,7 +595,7 @@ class CursorPagination(BasePagination): return _positive_int( request.query_params[self.page_size_query_param], strict=True, - cutoff=self.max_page_size + cutoff=self.max_page_size, ) except (KeyError, ValueError): pass @@ -686,8 +703,9 @@ class CursorPagination(BasePagination): Return a tuple of strings, that may be used in an `order_by` method. """ ordering_filters = [ - filter_cls for filter_cls in getattr(view, 'filter_backends', []) - if hasattr(filter_cls, 'get_ordering') + filter_cls + for filter_cls in getattr(view, "filter_backends", []) + if hasattr(filter_cls, "get_ordering") ] if ordering_filters: @@ -697,29 +715,27 @@ class CursorPagination(BasePagination): filter_instance = filter_cls() ordering = filter_instance.get_ordering(request, queryset, view) assert ordering is not None, ( - 'Using cursor pagination, but filter class {filter_cls} ' - 'returned a `None` ordering.'.format( - filter_cls=filter_cls.__name__ - ) + "Using cursor pagination, but filter class {filter_cls} " + "returned a `None` ordering.".format(filter_cls=filter_cls.__name__) ) else: # The default case is to check for an `ordering` attribute # on this pagination instance. ordering = self.ordering assert ordering is not None, ( - 'Using cursor pagination, but no ordering attribute was declared ' - 'on the pagination class.' + "Using cursor pagination, but no ordering attribute was declared " + "on the pagination class." ) - assert '__' not in ordering, ( - 'Cursor pagination does not support double underscore lookups ' - 'for orderings. Orderings should be an unchanging, unique or ' + assert "__" not in ordering, ( + "Cursor pagination does not support double underscore lookups " + "for orderings. Orderings should be an unchanging, unique or " 'nearly-unique field on the model, such as "-created" or "pk".' ) - assert isinstance(ordering, (six.string_types, list, tuple)), ( - 'Invalid ordering. Expected string or tuple, but got {type}'.format( - type=type(ordering).__name__ - ) + assert isinstance( + ordering, (six.string_types, list, tuple) + ), "Invalid ordering. Expected string or tuple, but got {type}".format( + type=type(ordering).__name__ ) if isinstance(ordering, six.string_types): @@ -736,16 +752,16 @@ class CursorPagination(BasePagination): return None try: - querystring = b64decode(encoded.encode('ascii')).decode('ascii') + querystring = b64decode(encoded.encode("ascii")).decode("ascii") tokens = urlparse.parse_qs(querystring, keep_blank_values=True) - offset = tokens.get('o', ['0'])[0] + offset = tokens.get("o", ["0"])[0] offset = _positive_int(offset, cutoff=self.offset_cutoff) - reverse = tokens.get('r', ['0'])[0] + reverse = tokens.get("r", ["0"])[0] reverse = bool(int(reverse)) - position = tokens.get('p', [None])[0] + position = tokens.get("p", [None])[0] except (TypeError, ValueError): raise NotFound(self.invalid_cursor_message) @@ -757,18 +773,18 @@ class CursorPagination(BasePagination): """ tokens = {} if cursor.offset != 0: - tokens['o'] = str(cursor.offset) + tokens["o"] = str(cursor.offset) if cursor.reverse: - tokens['r'] = '1' + tokens["r"] = "1" if cursor.position is not None: - tokens['p'] = cursor.position + tokens["p"] = cursor.position querystring = urlparse.urlencode(tokens, doseq=True) - encoded = b64encode(querystring.encode('ascii')).decode('ascii') + encoded = b64encode(querystring.encode("ascii")).decode("ascii") return replace_query_param(self.base_url, self.cursor_query_param, encoded) def _get_position_from_instance(self, instance, ordering): - field_name = ordering[0].lstrip('-') + field_name = ordering[0].lstrip("-") if isinstance(instance, dict): attr = instance[field_name] else: @@ -776,16 +792,20 @@ class CursorPagination(BasePagination): return six.text_type(attr) def get_paginated_response(self, data): - return Response(OrderedDict([ - ('next', self.get_next_link()), - ('previous', self.get_previous_link()), - ('results', data) - ])) + return Response( + OrderedDict( + [ + ("next", self.get_next_link()), + ("previous", self.get_previous_link()), + ("results", data), + ] + ) + ) def get_html_context(self): return { - 'previous_url': self.get_previous_link(), - 'next_url': self.get_next_link() + "previous_url": self.get_previous_link(), + "next_url": self.get_next_link(), } def to_html(self): @@ -794,17 +814,21 @@ class CursorPagination(BasePagination): return template.render(context) def get_schema_fields(self, view): - assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' - assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' + assert ( + coreapi is not None + ), "coreapi must be installed to use `get_schema_fields()`" + assert ( + coreschema is not None + ), "coreschema must be installed to use `get_schema_fields()`" fields = [ coreapi.Field( name=self.cursor_query_param, required=False, - location='query', + location="query", schema=coreschema.String( - title='Cursor', - description=force_text(self.cursor_query_description) - ) + title="Cursor", + description=force_text(self.cursor_query_description), + ), ) ] if self.page_size_query_param is not None: @@ -812,11 +836,11 @@ class CursorPagination(BasePagination): coreapi.Field( name=self.page_size_query_param, required=False, - location='query', + location="query", schema=coreschema.Integer( - title='Page size', - description=force_text(self.page_size_query_description) - ) + title="Page size", + description=force_text(self.page_size_query_description), + ), ) ) return fields diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index 35d0d1aa7..22307ae2a 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -11,10 +11,12 @@ import codecs from django.conf import settings from django.core.files.uploadhandler import StopFutureHandlers from django.http import QueryDict -from django.http.multipartparser import ChunkIter -from django.http.multipartparser import \ - MultiPartParser as DjangoMultiPartParser -from django.http.multipartparser import MultiPartParserError, parse_header +from django.http.multipartparser import ( + ChunkIter, + MultiPartParser as DjangoMultiPartParser, + MultiPartParserError, + parse_header, +) from django.utils import six from django.utils.encoding import force_text from django.utils.six.moves.urllib import parse as urlparse @@ -36,6 +38,7 @@ class BaseParser(object): All parsers should extend `BaseParser`, specifying a `media_type` attribute, and overriding the `.parse()` method. """ + media_type = None def parse(self, stream, media_type=None, parser_context=None): @@ -51,7 +54,8 @@ class JSONParser(BaseParser): """ Parses JSON-serialized data. """ - media_type = 'application/json' + + media_type = "application/json" renderer_class = renderers.JSONRenderer strict = api_settings.STRICT_JSON @@ -60,21 +64,22 @@ class JSONParser(BaseParser): Parses the incoming bytestream as JSON and returns the resulting data. """ parser_context = parser_context or {} - encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) try: decoded_stream = codecs.getreader(encoding)(stream) parse_constant = json.strict_constant if self.strict else None return json.load(decoded_stream, parse_constant=parse_constant) except ValueError as exc: - raise ParseError('JSON parse error - %s' % six.text_type(exc)) + raise ParseError("JSON parse error - %s" % six.text_type(exc)) class FormParser(BaseParser): """ Parser for form data. """ - media_type = 'application/x-www-form-urlencoded' + + media_type = "application/x-www-form-urlencoded" def parse(self, stream, media_type=None, parser_context=None): """ @@ -82,7 +87,7 @@ class FormParser(BaseParser): and returns the resulting QueryDict. """ parser_context = parser_context or {} - encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) data = QueryDict(stream.read(), encoding=encoding) return data @@ -91,7 +96,8 @@ class MultiPartParser(BaseParser): """ Parser for multipart form data, which may include file data. """ - media_type = 'multipart/form-data' + + media_type = "multipart/form-data" def parse(self, stream, media_type=None, parser_context=None): """ @@ -102,10 +108,10 @@ class MultiPartParser(BaseParser): `.files` will be a `QueryDict` containing all the form files. """ parser_context = parser_context or {} - request = parser_context['request'] - encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + request = parser_context["request"] + encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) meta = request.META.copy() - meta['CONTENT_TYPE'] = media_type + meta["CONTENT_TYPE"] = media_type upload_handlers = request.upload_handlers try: @@ -113,17 +119,18 @@ class MultiPartParser(BaseParser): data, files = parser.parse() return DataAndFiles(data, files) except MultiPartParserError as exc: - raise ParseError('Multipart form parse error - %s' % six.text_type(exc)) + raise ParseError("Multipart form parse error - %s" % six.text_type(exc)) class FileUploadParser(BaseParser): """ Parser for file upload data. """ - media_type = '*/*' + + media_type = "*/*" errors = { - 'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream', - 'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.', + "unhandled": "FileUpload parse error - none of upload handlers can handle the stream", + "no_filename": "Missing filename. Request should include a Content-Disposition header with a filename parameter.", } def parse(self, stream, media_type=None, parser_context=None): @@ -135,34 +142,32 @@ class FileUploadParser(BaseParser): `.files` will be a `QueryDict` containing one 'file' element. """ parser_context = parser_context or {} - request = parser_context['request'] - encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET) + request = parser_context["request"] + encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) meta = request.META upload_handlers = request.upload_handlers filename = self.get_filename(stream, media_type, parser_context) if not filename: - raise ParseError(self.errors['no_filename']) + raise ParseError(self.errors["no_filename"]) # Note that this code is extracted from Django's handling of # file uploads in MultiPartParser. - content_type = meta.get('HTTP_CONTENT_TYPE', - meta.get('CONTENT_TYPE', '')) + content_type = meta.get("HTTP_CONTENT_TYPE", meta.get("CONTENT_TYPE", "")) try: - content_length = int(meta.get('HTTP_CONTENT_LENGTH', - meta.get('CONTENT_LENGTH', 0))) + content_length = int( + meta.get("HTTP_CONTENT_LENGTH", meta.get("CONTENT_LENGTH", 0)) + ) except (ValueError, TypeError): content_length = None # See if the handler will want to take care of the parsing. for handler in upload_handlers: - result = handler.handle_raw_input(stream, - meta, - content_length, - None, - encoding) + result = handler.handle_raw_input( + stream, meta, content_length, None, encoding + ) if result is not None: - return DataAndFiles({}, {'file': result[1]}) + return DataAndFiles({}, {"file": result[1]}) # This is the standard case. possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size] @@ -172,10 +177,9 @@ class FileUploadParser(BaseParser): for index, handler in enumerate(upload_handlers): try: - handler.new_file(None, filename, content_type, - content_length, encoding) + handler.new_file(None, filename, content_type, content_length, encoding) except StopFutureHandlers: - upload_handlers = upload_handlers[:index + 1] + upload_handlers = upload_handlers[: index + 1] break for chunk in chunks: @@ -189,9 +193,9 @@ class FileUploadParser(BaseParser): for index, handler in enumerate(upload_handlers): file_obj = handler.file_complete(counters[index]) if file_obj is not None: - return DataAndFiles({}, {'file': file_obj}) + return DataAndFiles({}, {"file": file_obj}) - raise ParseError(self.errors['unhandled']) + raise ParseError(self.errors["unhandled"]) def get_filename(self, stream, media_type, parser_context): """ @@ -199,17 +203,17 @@ class FileUploadParser(BaseParser): Then tries to parse Content-Disposition header. """ try: - return parser_context['kwargs']['filename'] + return parser_context["kwargs"]["filename"] except KeyError: pass try: - meta = parser_context['request'].META - disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8')) + meta = parser_context["request"].META + disposition = parse_header(meta["HTTP_CONTENT_DISPOSITION"].encode("utf-8")) filename_parm = disposition[1] - if 'filename*' in filename_parm: + if "filename*" in filename_parm: return self.get_encoded_filename(filename_parm) - return force_text(filename_parm['filename']) + return force_text(filename_parm["filename"]) except (AttributeError, KeyError, ValueError): pass @@ -218,10 +222,10 @@ class FileUploadParser(BaseParser): Handle encoded filenames per RFC6266. See also: https://tools.ietf.org/html/rfc2231#section-4 """ - encoded_filename = force_text(filename_parm['filename*']) + encoded_filename = force_text(filename_parm["filename*"]) try: - charset, lang, filename = encoded_filename.split('\'', 2) + charset, lang, filename = encoded_filename.split("'", 2) filename = urlparse.unquote(filename) except (ValueError, LookupError): - filename = force_text(filename_parm['filename']) + filename = force_text(filename_parm["filename"]) return filename diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 5d75f54ba..aff42caab 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -8,7 +8,8 @@ from django.utils import six from rest_framework import exceptions -SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS') + +SAFE_METHODS = ("GET", "HEAD", "OPTIONS") class OperationHolderMixin: @@ -56,16 +57,14 @@ class AND: self.op2 = op2 def has_permission(self, request, view): - return ( - self.op1.has_permission(request, view) and - self.op2.has_permission(request, view) + return self.op1.has_permission(request, view) and self.op2.has_permission( + request, view ) def has_object_permission(self, request, view, obj): - return ( - self.op1.has_object_permission(request, view, obj) and - self.op2.has_object_permission(request, view, obj) - ) + return self.op1.has_object_permission( + request, view, obj + ) and self.op2.has_object_permission(request, view, obj) class OR: @@ -74,16 +73,14 @@ class OR: self.op2 = op2 def has_permission(self, request, view): - return ( - self.op1.has_permission(request, view) or - self.op2.has_permission(request, view) + return self.op1.has_permission(request, view) or self.op2.has_permission( + request, view ) def has_object_permission(self, request, view, obj): - return ( - self.op1.has_object_permission(request, view, obj) or - self.op2.has_object_permission(request, view, obj) - ) + return self.op1.has_object_permission( + request, view, obj + ) or self.op2.has_object_permission(request, view, obj) class NOT: @@ -157,9 +154,9 @@ class IsAuthenticatedOrReadOnly(BasePermission): def has_permission(self, request, view): return bool( - request.method in SAFE_METHODS or - request.user and - request.user.is_authenticated + request.method in SAFE_METHODS + or request.user + and request.user.is_authenticated ) @@ -179,13 +176,13 @@ class DjangoModelPermissions(BasePermission): # Override this if you need to also provide 'view' permissions, # or if you want to provide custom permission codes. perms_map = { - 'GET': [], - 'OPTIONS': [], - 'HEAD': [], - 'POST': ['%(app_label)s.add_%(model_name)s'], - 'PUT': ['%(app_label)s.change_%(model_name)s'], - 'PATCH': ['%(app_label)s.change_%(model_name)s'], - 'DELETE': ['%(app_label)s.delete_%(model_name)s'], + "GET": [], + "OPTIONS": [], + "HEAD": [], + "POST": ["%(app_label)s.add_%(model_name)s"], + "PUT": ["%(app_label)s.change_%(model_name)s"], + "PATCH": ["%(app_label)s.change_%(model_name)s"], + "DELETE": ["%(app_label)s.delete_%(model_name)s"], } authenticated_users_only = True @@ -196,8 +193,8 @@ class DjangoModelPermissions(BasePermission): codes that the user is required to have. """ kwargs = { - 'app_label': model_cls._meta.app_label, - 'model_name': model_cls._meta.model_name + "app_label": model_cls._meta.app_label, + "model_name": model_cls._meta.model_name, } if method not in self.perms_map: @@ -206,16 +203,19 @@ class DjangoModelPermissions(BasePermission): return [perm % kwargs for perm in self.perms_map[method]] def _queryset(self, view): - assert hasattr(view, 'get_queryset') \ - or getattr(view, 'queryset', None) is not None, ( - 'Cannot apply {} on a view that does not set ' - '`.queryset` or have a `.get_queryset()` method.' - ).format(self.__class__.__name__) + assert ( + hasattr(view, "get_queryset") or getattr(view, "queryset", None) is not None + ), ( + "Cannot apply {} on a view that does not set " + "`.queryset` or have a `.get_queryset()` method." + ).format( + self.__class__.__name__ + ) - if hasattr(view, 'get_queryset'): + if hasattr(view, "get_queryset"): queryset = view.get_queryset() - assert queryset is not None, ( - '{}.get_queryset() returned None'.format(view.__class__.__name__) + assert queryset is not None, "{}.get_queryset() returned None".format( + view.__class__.__name__ ) return queryset return view.queryset @@ -223,11 +223,12 @@ class DjangoModelPermissions(BasePermission): def has_permission(self, request, view): # Workaround to ensure DjangoModelPermissions are not applied # to the root view when using DefaultRouter. - if getattr(view, '_ignore_model_permissions', False): + if getattr(view, "_ignore_model_permissions", False): return True if not request.user or ( - not request.user.is_authenticated and self.authenticated_users_only): + not request.user.is_authenticated and self.authenticated_users_only + ): return False queryset = self._queryset(view) @@ -241,6 +242,7 @@ class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): Similar to DjangoModelPermissions, except that anonymous users are allowed read-only access. """ + authenticated_users_only = False @@ -255,20 +257,21 @@ class DjangoObjectPermissions(DjangoModelPermissions): This permission can only be applied against view classes that provide a `.queryset` attribute. """ + perms_map = { - 'GET': [], - 'OPTIONS': [], - 'HEAD': [], - 'POST': ['%(app_label)s.add_%(model_name)s'], - 'PUT': ['%(app_label)s.change_%(model_name)s'], - 'PATCH': ['%(app_label)s.change_%(model_name)s'], - 'DELETE': ['%(app_label)s.delete_%(model_name)s'], + "GET": [], + "OPTIONS": [], + "HEAD": [], + "POST": ["%(app_label)s.add_%(model_name)s"], + "PUT": ["%(app_label)s.change_%(model_name)s"], + "PATCH": ["%(app_label)s.change_%(model_name)s"], + "DELETE": ["%(app_label)s.delete_%(model_name)s"], } def get_required_object_permissions(self, method, model_cls): kwargs = { - 'app_label': model_cls._meta.app_label, - 'model_name': model_cls._meta.model_name + "app_label": model_cls._meta.app_label, + "model_name": model_cls._meta.model_name, } if method not in self.perms_map: @@ -294,7 +297,7 @@ class DjangoObjectPermissions(DjangoModelPermissions): # to make another lookup. raise Http404 - read_perms = self.get_required_object_permissions('GET', model_cls) + read_perms = self.get_required_object_permissions("GET", model_cls) if not user.has_perms(read_perms, obj): raise Http404 diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 31c1e7561..ef2dd53f2 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -9,14 +9,16 @@ from django.db.models import Manager from django.db.models.query import QuerySet from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve from django.utils import six -from django.utils.encoding import ( - python_2_unicode_compatible, smart_text, uri_to_iri -) +from django.utils.encoding import python_2_unicode_compatible, smart_text, uri_to_iri from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ from rest_framework.fields import ( - Field, empty, get_attribute, is_simple_callable, iter_options + Field, + empty, + get_attribute, + is_simple_callable, + iter_options, ) from rest_framework.reverse import reverse from rest_framework.settings import api_settings @@ -28,7 +30,7 @@ def method_overridden(method_name, klass, instance): Determine if a method has been overridden. """ method = getattr(klass, method_name) - default_method = getattr(method, '__func__', method) # Python 3 compat + default_method = getattr(method, "__func__", method) # Python 3 compat return default_method is not getattr(instance, method_name).__func__ @@ -52,13 +54,14 @@ class Hyperlink(six.text_type): We use this for hyperlinked URLs that may render as a named link in some contexts, or render as a plain URL in others. """ + def __new__(self, url, obj): ret = six.text_type.__new__(self, url) ret.obj = obj return ret def __getnewargs__(self): - return(str(self), self.name,) + return (str(self), self.name) @property def name(self): @@ -77,6 +80,7 @@ class PKOnlyObject(object): instance, but still want to return an object with a .pk attribute, in order to keep the same interface as a regular model instance. """ + def __init__(self, pk): self.pk = pk @@ -87,9 +91,19 @@ class PKOnlyObject(object): # We assume that 'validators' are intended for the child serializer, # rather than the parent serializer. MANY_RELATION_KWARGS = ( - 'read_only', 'write_only', 'required', 'default', 'initial', 'source', - 'label', 'help_text', 'style', 'error_messages', 'allow_empty', - 'html_cutoff', 'html_cutoff_text' + "read_only", + "write_only", + "required", + "default", + "initial", + "source", + "label", + "help_text", + "style", + "error_messages", + "allow_empty", + "html_cutoff", + "html_cutoff_text", ) @@ -99,34 +113,34 @@ class RelatedField(Field): html_cutoff_text = None def __init__(self, **kwargs): - self.queryset = kwargs.pop('queryset', self.queryset) + self.queryset = kwargs.pop("queryset", self.queryset) cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF if cutoff_from_settings is not None: cutoff_from_settings = int(cutoff_from_settings) - self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings) + self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings) self.html_cutoff_text = kwargs.pop( - 'html_cutoff_text', - self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + "html_cutoff_text", + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT), ) - if not method_overridden('get_queryset', RelatedField, self): - assert self.queryset is not None or kwargs.get('read_only', None), ( - 'Relational field must provide a `queryset` argument, ' - 'override `get_queryset`, or set read_only=`True`.' + if not method_overridden("get_queryset", RelatedField, self): + assert self.queryset is not None or kwargs.get("read_only", None), ( + "Relational field must provide a `queryset` argument, " + "override `get_queryset`, or set read_only=`True`." ) - assert not (self.queryset is not None and kwargs.get('read_only', None)), ( - 'Relational fields should not provide a `queryset` argument, ' - 'when setting read_only=`True`.' + assert not (self.queryset is not None and kwargs.get("read_only", None)), ( + "Relational fields should not provide a `queryset` argument, " + "when setting read_only=`True`." ) - kwargs.pop('many', None) - kwargs.pop('allow_empty', None) + kwargs.pop("many", None) + kwargs.pop("allow_empty", None) super(RelatedField, self).__init__(**kwargs) def __new__(cls, *args, **kwargs): # We override this method in order to automagically create # `ManyRelatedField` classes instead when `many=True` is set. - if kwargs.pop('many', False): + if kwargs.pop("many", False): return cls.many_init(*args, **kwargs) return super(RelatedField, cls).__new__(cls, *args, **kwargs) @@ -147,7 +161,7 @@ class RelatedField(Field): kwargs['child'] = cls() return CustomManyRelatedField(*args, **kwargs) """ - list_kwargs = {'child_relation': cls(*args, **kwargs)} + list_kwargs = {"child_relation": cls(*args, **kwargs)} for key in kwargs: if key in MANY_RELATION_KWARGS: list_kwargs[key] = kwargs[key] @@ -155,7 +169,7 @@ class RelatedField(Field): def run_validation(self, data=empty): # We force empty strings to None values for relational fields. - if data == '': + if data == "": data = None return super(RelatedField, self).run_validation(data) @@ -201,13 +215,12 @@ class RelatedField(Field): if cutoff is not None: queryset = queryset[:cutoff] - return OrderedDict([ - ( - self.to_representation(item), - self.display_value(item) - ) - for item in queryset - ]) + return OrderedDict( + [ + (self.to_representation(item), self.display_value(item)) + for item in queryset + ] + ) @property def choices(self): @@ -221,7 +234,7 @@ class RelatedField(Field): return iter_options( self.get_choices(cutoff=self.html_cutoff), cutoff=self.html_cutoff, - cutoff_text=self.html_cutoff_text + cutoff_text=self.html_cutoff_text, ) def display_value(self, instance): @@ -235,7 +248,7 @@ class StringRelatedField(RelatedField): """ def __init__(self, **kwargs): - kwargs['read_only'] = True + kwargs["read_only"] = True super(StringRelatedField, self).__init__(**kwargs) def to_representation(self, value): @@ -244,13 +257,13 @@ class StringRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField): default_error_messages = { - 'required': _('This field is required.'), - 'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'), - 'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'), + "required": _("This field is required."), + "does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'), + "incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."), } def __init__(self, **kwargs): - self.pk_field = kwargs.pop('pk_field', None) + self.pk_field = kwargs.pop("pk_field", None) super(PrimaryKeyRelatedField, self).__init__(**kwargs) def use_pk_only_optimization(self): @@ -262,9 +275,9 @@ class PrimaryKeyRelatedField(RelatedField): try: return self.get_queryset().get(pk=data) except ObjectDoesNotExist: - self.fail('does_not_exist', pk_value=data) + self.fail("does_not_exist", pk_value=data) except (TypeError, ValueError): - self.fail('incorrect_type', data_type=type(data).__name__) + self.fail("incorrect_type", data_type=type(data).__name__) def to_representation(self, value): if self.pk_field is not None: @@ -273,24 +286,26 @@ class PrimaryKeyRelatedField(RelatedField): class HyperlinkedRelatedField(RelatedField): - lookup_field = 'pk' + lookup_field = "pk" view_name = None default_error_messages = { - 'required': _('This field is required.'), - 'no_match': _('Invalid hyperlink - No URL match.'), - 'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'), - 'does_not_exist': _('Invalid hyperlink - Object does not exist.'), - 'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'), + "required": _("This field is required."), + "no_match": _("Invalid hyperlink - No URL match."), + "incorrect_match": _("Invalid hyperlink - Incorrect URL match."), + "does_not_exist": _("Invalid hyperlink - Object does not exist."), + "incorrect_type": _( + "Incorrect type. Expected URL string, received {data_type}." + ), } def __init__(self, view_name=None, **kwargs): if view_name is not None: self.view_name = view_name - assert self.view_name is not None, 'The `view_name` argument is required.' - self.lookup_field = kwargs.pop('lookup_field', self.lookup_field) - self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field) - self.format = kwargs.pop('format', None) + assert self.view_name is not None, "The `view_name` argument is required." + self.lookup_field = kwargs.pop("lookup_field", self.lookup_field) + self.lookup_url_kwarg = kwargs.pop("lookup_url_kwarg", self.lookup_field) + self.format = kwargs.pop("format", None) # We include this simply for dependency injection in tests. # We can't add it as a class attributes or it would expect an @@ -300,7 +315,7 @@ class HyperlinkedRelatedField(RelatedField): super(HyperlinkedRelatedField, self).__init__(**kwargs) def use_pk_only_optimization(self): - return self.lookup_field == 'pk' + return self.lookup_field == "pk" def get_object(self, view_name, view_args, view_kwargs): """ @@ -330,7 +345,7 @@ class HyperlinkedRelatedField(RelatedField): attributes are not configured to correctly match the URL conf. """ # Unsaved objects will not yet have a valid URL. - if hasattr(obj, 'pk') and obj.pk in (None, ''): + if hasattr(obj, "pk") and obj.pk in (None, ""): return None lookup_value = getattr(obj, self.lookup_field) @@ -338,25 +353,25 @@ class HyperlinkedRelatedField(RelatedField): return self.reverse(view_name, kwargs=kwargs, request=request, format=format) def to_internal_value(self, data): - request = self.context.get('request', None) + request = self.context.get("request", None) try: - http_prefix = data.startswith(('http:', 'https:')) + http_prefix = data.startswith(("http:", "https:")) except AttributeError: - self.fail('incorrect_type', data_type=type(data).__name__) + self.fail("incorrect_type", data_type=type(data).__name__) if http_prefix: # If needed convert absolute URLs to relative path data = urlparse.urlparse(data).path prefix = get_script_prefix() if data.startswith(prefix): - data = '/' + data[len(prefix):] + data = "/" + data[len(prefix) :] data = uri_to_iri(data) try: match = resolve(data) except Resolver404: - self.fail('no_match') + self.fail("no_match") try: expected_viewname = request.versioning_scheme.get_versioned_viewname( @@ -366,22 +381,22 @@ class HyperlinkedRelatedField(RelatedField): expected_viewname = self.view_name if match.view_name != expected_viewname: - self.fail('incorrect_match') + self.fail("incorrect_match") try: return self.get_object(match.view_name, match.args, match.kwargs) except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError): - self.fail('does_not_exist') + self.fail("does_not_exist") def to_representation(self, value): - assert 'request' in self.context, ( + assert "request" in self.context, ( "`%s` requires the request in the serializer" " context. Add `context={'request': request}` when instantiating " "the serializer." % self.__class__.__name__ ) - request = self.context['request'] - format = self.context.get('format', None) + request = self.context["request"] + format = self.context.get("format", None) # By default use whatever format is given for the current context # unless the target is a different type to the source. @@ -400,13 +415,13 @@ class HyperlinkedRelatedField(RelatedField): url = self.get_url(value, self.view_name, request, format) except NoReverseMatch: msg = ( - 'Could not resolve URL for hyperlinked relationship using ' + "Could not resolve URL for hyperlinked relationship using " 'view name "%s". You may have failed to include the related ' - 'model in your API, or incorrectly configured the ' - '`lookup_field` attribute on this field.' + "model in your API, or incorrectly configured the " + "`lookup_field` attribute on this field." ) - if value in ('', None): - value_string = {'': 'the empty string', None: 'None'}[value] + if value in ("", None): + value_string = {"": "the empty string", None: "None"}[value] msg += ( " WARNING: The value of the field on the model instance " "was %s, which may be why it didn't match any " @@ -429,9 +444,9 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField): """ def __init__(self, view_name=None, **kwargs): - assert view_name is not None, 'The `view_name` argument is required.' - kwargs['read_only'] = True - kwargs['source'] = '*' + assert view_name is not None, "The `view_name` argument is required." + kwargs["read_only"] = True + kwargs["source"] = "*" super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs) def use_pk_only_optimization(self): @@ -445,13 +460,14 @@ class SlugRelatedField(RelatedField): A read-write field that represents the target of the relationship by a unique 'slug' attribute. """ + default_error_messages = { - 'does_not_exist': _('Object with {slug_name}={value} does not exist.'), - 'invalid': _('Invalid value.'), + "does_not_exist": _("Object with {slug_name}={value} does not exist."), + "invalid": _("Invalid value."), } def __init__(self, slug_field=None, **kwargs): - assert slug_field is not None, 'The `slug_field` argument is required.' + assert slug_field is not None, "The `slug_field` argument is required." self.slug_field = slug_field super(SlugRelatedField, self).__init__(**kwargs) @@ -459,9 +475,11 @@ class SlugRelatedField(RelatedField): try: return self.get_queryset().get(**{self.slug_field: data}) except ObjectDoesNotExist: - self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data)) + self.fail( + "does_not_exist", slug_name=self.slug_field, value=smart_text(data) + ) except (TypeError, ValueError): - self.fail('invalid') + self.fail("invalid") def to_representation(self, obj): return getattr(obj, self.slug_field) @@ -479,31 +497,32 @@ class ManyRelatedField(Field): You shouldn't generally need to be using this class directly yourself, and should instead simply set 'many=True' on the relationship. """ + initial = [] default_empty_html = [] default_error_messages = { - 'not_a_list': _('Expected a list of items but got type "{input_type}".'), - 'empty': _('This list may not be empty.') + "not_a_list": _('Expected a list of items but got type "{input_type}".'), + "empty": _("This list may not be empty."), } html_cutoff = None html_cutoff_text = None def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation - self.allow_empty = kwargs.pop('allow_empty', True) + self.allow_empty = kwargs.pop("allow_empty", True) cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF if cutoff_from_settings is not None: cutoff_from_settings = int(cutoff_from_settings) - self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings) + self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings) self.html_cutoff_text = kwargs.pop( - 'html_cutoff_text', - self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + "html_cutoff_text", + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT), ) - assert child_relation is not None, '`child_relation` is a required argument.' + assert child_relation is not None, "`child_relation` is a required argument." super(ManyRelatedField, self).__init__(*args, **kwargs) - self.child_relation.bind(field_name='', parent=self) + self.child_relation.bind(field_name="", parent=self) def get_value(self, dictionary): # We override the default field access in order to support @@ -511,36 +530,30 @@ class ManyRelatedField(Field): if html.is_html_input(dictionary): # Don't return [] if the update is partial if self.field_name not in dictionary: - if getattr(self.root, 'partial', False): + if getattr(self.root, "partial", False): return empty return dictionary.getlist(self.field_name) return dictionary.get(self.field_name, empty) def to_internal_value(self, data): - if isinstance(data, six.text_type) or not hasattr(data, '__iter__'): - self.fail('not_a_list', input_type=type(data).__name__) + if isinstance(data, six.text_type) or not hasattr(data, "__iter__"): + self.fail("not_a_list", input_type=type(data).__name__) if not self.allow_empty and len(data) == 0: - self.fail('empty') + self.fail("empty") - return [ - self.child_relation.to_internal_value(item) - for item in data - ] + return [self.child_relation.to_internal_value(item) for item in data] def get_attribute(self, instance): # Can't have any relationships if not created - if hasattr(instance, 'pk') and instance.pk is None: + if hasattr(instance, "pk") and instance.pk is None: return [] relationship = get_attribute(instance, self.source_attrs) - return relationship.all() if hasattr(relationship, 'all') else relationship + return relationship.all() if hasattr(relationship, "all") else relationship def to_representation(self, iterable): - return [ - self.child_relation.to_representation(value) - for value in iterable - ] + return [self.child_relation.to_representation(value) for value in iterable] def get_choices(self, cutoff=None): return self.child_relation.get_choices(cutoff) @@ -557,5 +570,5 @@ class ManyRelatedField(Field): return iter_options( self.get_choices(cutoff=self.html_cutoff), cutoff=self.html_cutoff, - cutoff_text=self.html_cutoff_text + cutoff_text=self.html_cutoff_text, ) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index f043e6327..176c2b2d2 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -25,8 +25,13 @@ from django.utils.six.moves.urllib import parse as urlparse from rest_framework import VERSION, exceptions, serializers, status from rest_framework.compat import ( - INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema, - pygments_css, yaml + INDENT_SEPARATORS, + LONG_SEPARATORS, + SHORT_SEPARATORS, + coreapi, + coreschema, + pygments_css, + yaml, ) from rest_framework.exceptions import ParseError from rest_framework.request import is_form_media_type, override_method @@ -45,21 +50,23 @@ class BaseRenderer(object): All renderers should extend this class, setting the `media_type` and `format` attributes, and override the `.render()` method. """ + media_type = None format = None - charset = 'utf-8' - render_style = 'text' + charset = "utf-8" + render_style = "text" def render(self, data, accepted_media_type=None, renderer_context=None): - raise NotImplementedError('Renderer class requires .render() to be implemented') + raise NotImplementedError("Renderer class requires .render() to be implemented") class JSONRenderer(BaseRenderer): """ Renderer which serializes to JSON. """ - media_type = 'application/json' - format = 'json' + + media_type = "application/json" + format = "json" encoder_class = encoders.JSONEncoder ensure_ascii = not api_settings.UNICODE_JSON compact = api_settings.COMPACT_JSON @@ -76,15 +83,15 @@ class JSONRenderer(BaseRenderer): # If the media type looks like 'application/json; indent=4', # then pretty print the result. # Note that we coerce `indent=0` into `indent=None`. - base_media_type, params = parse_header(accepted_media_type.encode('ascii')) + base_media_type, params = parse_header(accepted_media_type.encode("ascii")) try: - return zero_as_none(max(min(int(params['indent']), 8), 0)) + return zero_as_none(max(min(int(params["indent"]), 8), 0)) except (KeyError, ValueError, TypeError): pass # If 'indent' is provided in the context, then pretty print the result. # E.g. If we're being called by the BrowsableAPIRenderer. - return renderer_context.get('indent', None) + return renderer_context.get("indent", None) def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -102,9 +109,12 @@ class JSONRenderer(BaseRenderer): separators = INDENT_SEPARATORS ret = json.dumps( - data, cls=self.encoder_class, - indent=indent, ensure_ascii=self.ensure_ascii, - allow_nan=not self.strict, separators=separators + data, + cls=self.encoder_class, + indent=indent, + ensure_ascii=self.ensure_ascii, + allow_nan=not self.strict, + separators=separators, ) # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, @@ -116,8 +126,8 @@ class JSONRenderer(BaseRenderer): # that is a strict javascript subset. If bytes were returned # by json.dumps() then we don't have these characters in any case. # See: http://timelessrepo.com/json-isnt-a-javascript-subset - ret = ret.replace('\u2028', '\\u2028').replace('\u2029', '\\u2029') - return bytes(ret.encode('utf-8')) + ret = ret.replace("\u2028", "\\u2028").replace("\u2029", "\\u2029") + return bytes(ret.encode("utf-8")) return ret @@ -140,14 +150,12 @@ class TemplateHTMLRenderer(BaseRenderer): For pre-rendered HTML, see StaticHTMLRenderer. """ - media_type = 'text/html' - format = 'html' + + media_type = "text/html" + format = "html" template_name = None - exception_template_names = [ - '%(status_code)s.html', - 'api_exception.html' - ] - charset = 'utf-8' + exception_template_names = ["%(status_code)s.html", "api_exception.html"] + charset = "utf-8" def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -160,9 +168,9 @@ class TemplateHTMLRenderer(BaseRenderer): 3. The return result of calling view.get_template_names(). """ renderer_context = renderer_context or {} - view = renderer_context['view'] - request = renderer_context['request'] - response = renderer_context['response'] + view = renderer_context["view"] + request = renderer_context["request"] + response = renderer_context["response"] if response.exception: template = self.get_exception_template(response) @@ -170,7 +178,7 @@ class TemplateHTMLRenderer(BaseRenderer): template_names = self.get_template_names(response, view) template = self.resolve_template(template_names) - if hasattr(self, 'resolve_context'): + if hasattr(self, "resolve_context"): # Fallback for older versions. context = self.resolve_context(data, request, response) else: @@ -181,9 +189,9 @@ class TemplateHTMLRenderer(BaseRenderer): return loader.select_template(template_names) def get_template_context(self, data, renderer_context): - response = renderer_context['response'] + response = renderer_context["response"] if response.exception: - data['status_code'] = response.status_code + data["status_code"] = response.status_code return data def get_template_names(self, response, view): @@ -191,25 +199,27 @@ class TemplateHTMLRenderer(BaseRenderer): return [response.template_name] elif self.template_name: return [self.template_name] - elif hasattr(view, 'get_template_names'): + elif hasattr(view, "get_template_names"): return view.get_template_names() - elif hasattr(view, 'template_name'): + elif hasattr(view, "template_name"): return [view.template_name] raise ImproperlyConfigured( - 'Returned a template response with no `template_name` attribute set on either the view or response' + "Returned a template response with no `template_name` attribute set on either the view or response" ) def get_exception_template(self, response): - template_names = [name % {'status_code': response.status_code} - for name in self.exception_template_names] + template_names = [ + name % {"status_code": response.status_code} + for name in self.exception_template_names + ] try: # Try to find an appropriate error template return self.resolve_template(template_names) except Exception: # Fall back to using eg '404 Not Found' - body = '%d %s' % (response.status_code, response.status_text.title()) - template = engines['django'].from_string(body) + body = "%d %s" % (response.status_code, response.status_text.title()) + template = engines["django"].from_string(body) return template @@ -227,18 +237,19 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): For template rendered HTML, see TemplateHTMLRenderer. """ - media_type = 'text/html' - format = 'html' - charset = 'utf-8' + + media_type = "text/html" + format = "html" + charset = "utf-8" def render(self, data, accepted_media_type=None, renderer_context=None): renderer_context = renderer_context or {} - response = renderer_context.get('response') + response = renderer_context.get("response") if response and response.exception: - request = renderer_context['request'] + request = renderer_context["request"] template = self.get_exception_template(response) - if hasattr(self, 'resolve_context'): + if hasattr(self, "resolve_context"): context = self.resolve_context(data, request, response) else: context = self.get_template_context(data, renderer_context) @@ -258,107 +269,96 @@ class HTMLFormRenderer(BaseRenderer): Note that rendering of field and form errors is not currently supported. """ - media_type = 'text/html' - format = 'form' - charset = 'utf-8' - template_pack = 'rest_framework/vertical/' - base_template = 'form.html' - default_style = ClassLookupDict({ - serializers.Field: { - 'base_template': 'input.html', - 'input_type': 'text' - }, - serializers.EmailField: { - 'base_template': 'input.html', - 'input_type': 'email' - }, - serializers.URLField: { - 'base_template': 'input.html', - 'input_type': 'url' - }, - serializers.IntegerField: { - 'base_template': 'input.html', - 'input_type': 'number' - }, - serializers.FloatField: { - 'base_template': 'input.html', - 'input_type': 'number' - }, - serializers.DateTimeField: { - 'base_template': 'input.html', - 'input_type': 'datetime-local' - }, - serializers.DateField: { - 'base_template': 'input.html', - 'input_type': 'date' - }, - serializers.TimeField: { - 'base_template': 'input.html', - 'input_type': 'time' - }, - serializers.FileField: { - 'base_template': 'input.html', - 'input_type': 'file' - }, - serializers.BooleanField: { - 'base_template': 'checkbox.html' - }, - serializers.ChoiceField: { - 'base_template': 'select.html', # Also valid: 'radio.html' - }, - serializers.MultipleChoiceField: { - 'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html' - }, - serializers.RelatedField: { - 'base_template': 'select.html', # Also valid: 'radio.html' - }, - serializers.ManyRelatedField: { - 'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html' - }, - serializers.Serializer: { - 'base_template': 'fieldset.html' - }, - serializers.ListSerializer: { - 'base_template': 'list_fieldset.html' - }, - serializers.ListField: { - 'base_template': 'list_field.html' - }, - serializers.DictField: { - 'base_template': 'dict_field.html' - }, - serializers.FilePathField: { - 'base_template': 'select.html', - }, - serializers.JSONField: { - 'base_template': 'textarea.html', - }, - }) + media_type = "text/html" + format = "form" + charset = "utf-8" + template_pack = "rest_framework/vertical/" + base_template = "form.html" + + default_style = ClassLookupDict( + { + serializers.Field: {"base_template": "input.html", "input_type": "text"}, + serializers.EmailField: { + "base_template": "input.html", + "input_type": "email", + }, + serializers.URLField: {"base_template": "input.html", "input_type": "url"}, + serializers.IntegerField: { + "base_template": "input.html", + "input_type": "number", + }, + serializers.FloatField: { + "base_template": "input.html", + "input_type": "number", + }, + serializers.DateTimeField: { + "base_template": "input.html", + "input_type": "datetime-local", + }, + serializers.DateField: { + "base_template": "input.html", + "input_type": "date", + }, + serializers.TimeField: { + "base_template": "input.html", + "input_type": "time", + }, + serializers.FileField: { + "base_template": "input.html", + "input_type": "file", + }, + serializers.BooleanField: {"base_template": "checkbox.html"}, + serializers.ChoiceField: { + "base_template": "select.html" # Also valid: 'radio.html' + }, + serializers.MultipleChoiceField: { + "base_template": "select_multiple.html" # Also valid: 'checkbox_multiple.html' + }, + serializers.RelatedField: { + "base_template": "select.html" # Also valid: 'radio.html' + }, + serializers.ManyRelatedField: { + "base_template": "select_multiple.html" # Also valid: 'checkbox_multiple.html' + }, + serializers.Serializer: {"base_template": "fieldset.html"}, + serializers.ListSerializer: {"base_template": "list_fieldset.html"}, + serializers.ListField: {"base_template": "list_field.html"}, + serializers.DictField: {"base_template": "dict_field.html"}, + serializers.FilePathField: {"base_template": "select.html"}, + serializers.JSONField: {"base_template": "textarea.html"}, + } + ) def render_field(self, field, parent_style): if isinstance(field._field, serializers.HiddenField): - return '' + return "" style = dict(self.default_style[field]) style.update(field.style) - if 'template_pack' not in style: - style['template_pack'] = parent_style.get('template_pack', self.template_pack) - style['renderer'] = self + if "template_pack" not in style: + style["template_pack"] = parent_style.get( + "template_pack", self.template_pack + ) + style["renderer"] = self # Get a clone of the field with text-only value representation. field = field.as_form_field() - if style.get('input_type') == 'datetime-local' and isinstance(field.value, six.text_type): - field.value = field.value.rstrip('Z') + if style.get("input_type") == "datetime-local" and isinstance( + field.value, six.text_type + ): + field.value = field.value.rstrip("Z") - if 'template' in style: - template_name = style['template'] + if "template" in style: + template_name = style["template"] else: - template_name = style['template_pack'].strip('/') + '/' + style['base_template'] + template_name = ( + style["template_pack"].strip("/") + "/" + style["base_template"] + ) template = loader.get_template(template_name) - context = {'field': field, 'style': style} + context = {"field": field, "style": style} return template.render(context) def render(self, data, accepted_media_type=None, renderer_context=None): @@ -368,18 +368,15 @@ class HTMLFormRenderer(BaseRenderer): renderer_context = renderer_context or {} form = data.serializer - style = renderer_context.get('style', {}) - if 'template_pack' not in style: - style['template_pack'] = self.template_pack - style['renderer'] = self + style = renderer_context.get("style", {}) + if "template_pack" not in style: + style["template_pack"] = self.template_pack + style["renderer"] = self - template_pack = style['template_pack'].strip('/') - template_name = template_pack + '/' + self.base_template + template_pack = style["template_pack"].strip("/") + template_name = template_pack + "/" + self.base_template template = loader.get_template(template_name) - context = { - 'form': form, - 'style': style - } + context = {"form": form, "style": style} return template.render(context) @@ -387,12 +384,13 @@ class BrowsableAPIRenderer(BaseRenderer): """ HTML renderer used to self-document the API. """ - media_type = 'text/html' - format = 'api' - template = 'rest_framework/api.html' - filter_template = 'rest_framework/filters/base.html' - code_style = 'emacs' - charset = 'utf-8' + + media_type = "text/html" + format = "api" + template = "rest_framework/api.html" + filter_template = "rest_framework/filters/base.html" + code_style = "emacs" + charset = "utf-8" form_renderer_class = HTMLFormRenderer def get_default_renderer(self, view): @@ -400,10 +398,16 @@ class BrowsableAPIRenderer(BaseRenderer): Return an instance of the first valid renderer. (Don't use another documenting renderer.) """ - renderers = [renderer for renderer in view.renderer_classes - if not issubclass(renderer, BrowsableAPIRenderer)] - non_template_renderers = [renderer for renderer in renderers - if not hasattr(renderer, 'get_template_names')] + renderers = [ + renderer + for renderer in view.renderer_classes + if not issubclass(renderer, BrowsableAPIRenderer) + ] + non_template_renderers = [ + renderer + for renderer in renderers + if not hasattr(renderer, "get_template_names") + ] if not renderers: return None @@ -411,23 +415,23 @@ class BrowsableAPIRenderer(BaseRenderer): return non_template_renderers[0]() return renderers[0]() - def get_content(self, renderer, data, - accepted_media_type, renderer_context): + def get_content(self, renderer, data, accepted_media_type, renderer_context): """ Get the content as if it had been rendered by the default non-documenting renderer. """ if not renderer: - return '[No renderers were found]' + return "[No renderers were found]" - renderer_context['indent'] = 4 + renderer_context["indent"] = 4 content = renderer.render(data, accepted_media_type, renderer_context) - render_style = getattr(renderer, 'render_style', 'text') - assert render_style in ['text', 'binary'], 'Expected .render_style ' \ - '"text" or "binary", but got "%s"' % render_style - if render_style == 'binary': - return '[%d bytes of binary content]' % len(content) + render_style = getattr(renderer, "render_style", "text") + assert render_style in ["text", "binary"], ( + "Expected .render_style " '"text" or "binary", but got "%s"' % render_style + ) + if render_style == "binary": + return "[%d bytes of binary content]" % len(content) return content @@ -446,11 +450,13 @@ class BrowsableAPIRenderer(BaseRenderer): return False # Doesn't have permissions return True - def _get_serializer(self, serializer_class, view_instance, request, *args, **kwargs): - kwargs['context'] = { - 'request': request, - 'format': self.format, - 'view': view_instance + def _get_serializer( + self, serializer_class, view_instance, request, *args, **kwargs + ): + kwargs["context"] = { + "request": request, + "format": self.format, + "view": view_instance, } return serializer_class(*args, **kwargs) @@ -462,9 +468,9 @@ class BrowsableAPIRenderer(BaseRenderer): In the absence of the View having an associated form then return None. """ # See issue #2089 for refactoring this. - serializer = getattr(data, 'serializer', None) - if serializer and not getattr(serializer, 'many', False): - instance = getattr(serializer, 'instance', None) + serializer = getattr(data, "serializer", None) + if serializer and not getattr(serializer, "many", False): + instance = getattr(serializer, "instance", None) if isinstance(instance, Page): instance = None else: @@ -475,7 +481,7 @@ class BrowsableAPIRenderer(BaseRenderer): # serializer instance, rather than dynamically creating a new one. if request.method == method and serializer is not None: try: - kwargs = {'data': request.data} + kwargs = {"data": request.data} except ParseError: kwargs = {} existing_serializer = serializer @@ -487,15 +493,14 @@ class BrowsableAPIRenderer(BaseRenderer): if not self.show_form_for_method(view, method, request, instance): return - if method in ('DELETE', 'OPTIONS'): + if method in ("DELETE", "OPTIONS"): return True # Don't actually need to return a form - has_serializer = getattr(view, 'get_serializer', None) - has_serializer_class = getattr(view, 'serializer_class', None) + has_serializer = getattr(view, "get_serializer", None) + has_serializer_class = getattr(view, "serializer_class", None) - if ( - (not has_serializer and not has_serializer_class) or - not any(is_form_media_type(parser.media_type) for parser in view.parser_classes) + if (not has_serializer and not has_serializer_class) or not any( + is_form_media_type(parser.media_type) for parser in view.parser_classes ): return @@ -506,30 +511,36 @@ class BrowsableAPIRenderer(BaseRenderer): pass if has_serializer: - if method in ('PUT', 'PATCH'): + if method in ("PUT", "PATCH"): serializer = view.get_serializer(instance=instance, **kwargs) else: serializer = view.get_serializer(**kwargs) else: # at this point we must have a serializer_class - if method in ('PUT', 'PATCH'): - serializer = self._get_serializer(view.serializer_class, view, - request, instance=instance, **kwargs) + if method in ("PUT", "PATCH"): + serializer = self._get_serializer( + view.serializer_class, + view, + request, + instance=instance, + **kwargs + ) else: - serializer = self._get_serializer(view.serializer_class, view, - request, **kwargs) + serializer = self._get_serializer( + view.serializer_class, view, request, **kwargs + ) return self.render_form_for_serializer(serializer) def render_form_for_serializer(self, serializer): - if hasattr(serializer, 'initial_data'): + if hasattr(serializer, "initial_data"): serializer.is_valid() form_renderer = self.form_renderer_class() return form_renderer.render( serializer.data, self.accepted_media_type, - {'style': {'template_pack': 'rest_framework/horizontal'}} + {"style": {"template_pack": "rest_framework/horizontal"}}, ) def get_raw_data_form(self, data, view, method, request): @@ -539,9 +550,9 @@ class BrowsableAPIRenderer(BaseRenderer): (Which are typically application/x-www-form-urlencoded) """ # See issue #2089 for refactoring this. - serializer = getattr(data, 'serializer', None) - if serializer and not getattr(serializer, 'many', False): - instance = getattr(serializer, 'instance', None) + serializer = getattr(data, "serializer", None) + if serializer and not getattr(serializer, "many", False): + instance = getattr(serializer, "instance", None) if isinstance(instance, Page): instance = None else: @@ -554,12 +565,12 @@ class BrowsableAPIRenderer(BaseRenderer): # If possible, serialize the initial content for the generic form default_parser = view.parser_classes[0] - renderer_class = getattr(default_parser, 'renderer_class', None) - if hasattr(view, 'get_serializer') and renderer_class: + renderer_class = getattr(default_parser, "renderer_class", None) + if hasattr(view, "get_serializer") and renderer_class: # View has a serializer defined and parser class has a # corresponding renderer that can be used to render the data. - if method in ('PUT', 'PATCH'): + if method in ("PUT", "PATCH"): serializer = view.get_serializer(instance=instance) else: serializer = view.get_serializer() @@ -568,7 +579,7 @@ class BrowsableAPIRenderer(BaseRenderer): renderer = renderer_class() accepted = self.accepted_media_type context = self.renderer_context.copy() - context['indent'] = 4 + context["indent"] = 4 # strip HiddenField from output data = serializer.data.copy() @@ -577,7 +588,7 @@ class BrowsableAPIRenderer(BaseRenderer): data.pop(name, None) content = renderer.render(data, accepted, context) # Renders returns bytes, but CharField expects a str. - content = content.decode('utf-8') + content = content.decode("utf-8") else: content = None @@ -589,16 +600,16 @@ class BrowsableAPIRenderer(BaseRenderer): class GenericContentForm(forms.Form): _content_type = forms.ChoiceField( - label='Media type', + label="Media type", choices=choices, initial=initial, - widget=forms.Select(attrs={'data-override': 'content-type'}) + widget=forms.Select(attrs={"data-override": "content-type"}), ) _content = forms.CharField( - label='Content', - widget=forms.Textarea(attrs={'data-override': 'content'}), + label="Content", + widget=forms.Textarea(attrs={"data-override": "content"}), initial=content, - required=False + required=False, ) return GenericContentForm() @@ -608,23 +619,23 @@ class BrowsableAPIRenderer(BaseRenderer): def get_description(self, view, status_code): if status_code in (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN): - return '' + return "" return view.get_view_description(html=True) def get_breadcrumbs(self, request): return get_breadcrumbs(request.path, request) def get_extra_actions(self, view): - if hasattr(view, 'get_extra_action_url_map'): + if hasattr(view, "get_extra_action_url_map"): return view.get_extra_action_url_map() return None def get_filter_form(self, data, view, request): - if not hasattr(view, 'get_queryset') or not hasattr(view, 'filter_backends'): + if not hasattr(view, "get_queryset") or not hasattr(view, "filter_backends"): return # Infer if this is a list view or not. - paginator = getattr(view, 'paginator', None) + paginator = getattr(view, "paginator", None) if isinstance(data, list): pass elif paginator is not None and data is not None: @@ -638,7 +649,7 @@ class BrowsableAPIRenderer(BaseRenderer): queryset = view.get_queryset() elements = [] for backend in view.filter_backends: - if hasattr(backend, 'to_html'): + if hasattr(backend, "to_html"): html = backend().to_html(request, queryset, view) if html: elements.append(html) @@ -647,78 +658,76 @@ class BrowsableAPIRenderer(BaseRenderer): return template = loader.get_template(self.filter_template) - context = {'elements': elements} + context = {"elements": elements} return template.render(context) def get_context(self, data, accepted_media_type, renderer_context): """ Returns the context used to render. """ - view = renderer_context['view'] - request = renderer_context['request'] - response = renderer_context['response'] + view = renderer_context["view"] + request = renderer_context["request"] + response = renderer_context["response"] renderer = self.get_default_renderer(view) - raw_data_post_form = self.get_raw_data_form(data, view, 'POST', request) - raw_data_put_form = self.get_raw_data_form(data, view, 'PUT', request) - raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request) + raw_data_post_form = self.get_raw_data_form(data, view, "POST", request) + raw_data_put_form = self.get_raw_data_form(data, view, "PUT", request) + raw_data_patch_form = self.get_raw_data_form(data, view, "PATCH", request) raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form response_headers = OrderedDict(sorted(response.items())) - renderer_content_type = '' + renderer_content_type = "" if renderer: - renderer_content_type = '%s' % renderer.media_type + renderer_content_type = "%s" % renderer.media_type if renderer.charset: - renderer_content_type += ' ;%s' % renderer.charset - response_headers['Content-Type'] = renderer_content_type + renderer_content_type += " ;%s" % renderer.charset + response_headers["Content-Type"] = renderer_content_type - if getattr(view, 'paginator', None) and view.paginator.display_page_controls: + if getattr(view, "paginator", None) and view.paginator.display_page_controls: paginator = view.paginator else: paginator = None csrf_cookie_name = settings.CSRF_COOKIE_NAME csrf_header_name = settings.CSRF_HEADER_NAME - if csrf_header_name.startswith('HTTP_'): + if csrf_header_name.startswith("HTTP_"): csrf_header_name = csrf_header_name[5:] - csrf_header_name = csrf_header_name.replace('_', '-') + csrf_header_name = csrf_header_name.replace("_", "-") context = { - 'content': self.get_content(renderer, data, accepted_media_type, renderer_context), - 'code_style': pygments_css(self.code_style), - 'view': view, - 'request': request, - 'response': response, - 'user': request.user, - 'description': self.get_description(view, response.status_code), - 'name': self.get_name(view), - 'version': VERSION, - 'paginator': paginator, - 'breadcrumblist': self.get_breadcrumbs(request), - 'allowed_methods': view.allowed_methods, - 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], - 'response_headers': response_headers, - - 'put_form': self.get_rendered_html_form(data, view, 'PUT', request), - 'post_form': self.get_rendered_html_form(data, view, 'POST', request), - 'delete_form': self.get_rendered_html_form(data, view, 'DELETE', request), - 'options_form': self.get_rendered_html_form(data, view, 'OPTIONS', request), - - 'extra_actions': self.get_extra_actions(view), - - 'filter_form': self.get_filter_form(data, view, request), - - 'raw_data_put_form': raw_data_put_form, - 'raw_data_post_form': raw_data_post_form, - 'raw_data_patch_form': raw_data_patch_form, - 'raw_data_put_or_patch_form': raw_data_put_or_patch_form, - - 'display_edit_forms': bool(response.status_code != 403), - - 'api_settings': api_settings, - 'csrf_cookie_name': csrf_cookie_name, - 'csrf_header_name': csrf_header_name + "content": self.get_content( + renderer, data, accepted_media_type, renderer_context + ), + "code_style": pygments_css(self.code_style), + "view": view, + "request": request, + "response": response, + "user": request.user, + "description": self.get_description(view, response.status_code), + "name": self.get_name(view), + "version": VERSION, + "paginator": paginator, + "breadcrumblist": self.get_breadcrumbs(request), + "allowed_methods": view.allowed_methods, + "available_formats": [ + renderer_cls.format for renderer_cls in view.renderer_classes + ], + "response_headers": response_headers, + "put_form": self.get_rendered_html_form(data, view, "PUT", request), + "post_form": self.get_rendered_html_form(data, view, "POST", request), + "delete_form": self.get_rendered_html_form(data, view, "DELETE", request), + "options_form": self.get_rendered_html_form(data, view, "OPTIONS", request), + "extra_actions": self.get_extra_actions(view), + "filter_form": self.get_filter_form(data, view, request), + "raw_data_put_form": raw_data_put_form, + "raw_data_post_form": raw_data_post_form, + "raw_data_patch_form": raw_data_patch_form, + "raw_data_put_or_patch_form": raw_data_put_or_patch_form, + "display_edit_forms": bool(response.status_code != 403), + "api_settings": api_settings, + "csrf_cookie_name": csrf_cookie_name, + "csrf_header_name": csrf_header_name, } return context @@ -726,17 +735,17 @@ class BrowsableAPIRenderer(BaseRenderer): """ Render the HTML for the browsable API representation. """ - self.accepted_media_type = accepted_media_type or '' + self.accepted_media_type = accepted_media_type or "" self.renderer_context = renderer_context or {} template = loader.get_template(self.template) context = self.get_context(data, accepted_media_type, renderer_context) - ret = template.render(context, request=renderer_context['request']) + ret = template.render(context, request=renderer_context["request"]) # Munge DELETE Response code to allow us to return content # (Do this *after* we've rendered the template so that we include # the normal deletion response code in the output) - response = renderer_context['response'] + response = renderer_context["response"] if response.status_code == status.HTTP_204_NO_CONTENT: response.status_code = status.HTTP_200_OK @@ -744,46 +753,50 @@ class BrowsableAPIRenderer(BaseRenderer): class AdminRenderer(BrowsableAPIRenderer): - template = 'rest_framework/admin.html' - format = 'admin' + template = "rest_framework/admin.html" + format = "admin" def render(self, data, accepted_media_type=None, renderer_context=None): - self.accepted_media_type = accepted_media_type or '' + self.accepted_media_type = accepted_media_type or "" self.renderer_context = renderer_context or {} - response = renderer_context['response'] - request = renderer_context['request'] - view = self.renderer_context['view'] + response = renderer_context["response"] + request = renderer_context["request"] + view = self.renderer_context["view"] if response.status_code == status.HTTP_400_BAD_REQUEST: # Errors still need to display the list or detail information. # The only way we can get at that is to simulate a GET request. - self.error_form = self.get_rendered_html_form(data, view, request.method, request) - self.error_title = {'POST': 'Create', 'PUT': 'Edit'}.get(request.method, 'Errors') + self.error_form = self.get_rendered_html_form( + data, view, request.method, request + ) + self.error_title = {"POST": "Create", "PUT": "Edit"}.get( + request.method, "Errors" + ) - with override_method(view, request, 'GET') as request: + with override_method(view, request, "GET") as request: response = view.get(request, *view.args, **view.kwargs) data = response.data template = loader.get_template(self.template) context = self.get_context(data, accepted_media_type, renderer_context) - ret = template.render(context, request=renderer_context['request']) + ret = template.render(context, request=renderer_context["request"]) # Creation and deletion should use redirects in the admin style. - if response.status_code == status.HTTP_201_CREATED and 'Location' in response: + if response.status_code == status.HTTP_201_CREATED and "Location" in response: response.status_code = status.HTTP_303_SEE_OTHER - response['Location'] = request.build_absolute_uri() - ret = '' + response["Location"] = request.build_absolute_uri() + ret = "" if response.status_code == status.HTTP_204_NO_CONTENT: response.status_code = status.HTTP_303_SEE_OTHER try: # Attempt to get the parent breadcrumb URL. - response['Location'] = self.get_breadcrumbs(request)[-2][1] + response["Location"] = self.get_breadcrumbs(request)[-2][1] except KeyError: # Otherwise reload current URL to get a 'Not Found' page. - response['Location'] = request.full_path - ret = '' + response["Location"] = request.full_path + ret = "" return ret @@ -795,7 +808,7 @@ class AdminRenderer(BrowsableAPIRenderer): data, accepted_media_type, renderer_context ) - paginator = getattr(context['view'], 'paginator', None) + paginator = getattr(context["view"], "paginator", None) if paginator is not None and data is not None: try: results = paginator.get_results(data) @@ -806,29 +819,29 @@ class AdminRenderer(BrowsableAPIRenderer): if results is None: header = {} - style = 'detail' + style = "detail" elif isinstance(results, list): header = results[0] if results else {} - style = 'list' + style = "list" else: header = results - style = 'detail' + style = "detail" - columns = [key for key in header if key != 'url'] - details = [key for key in header if key != 'url'] + columns = [key for key in header if key != "url"] + details = [key for key in header if key != "url"] - if isinstance(results, list) and 'view' in renderer_context: + if isinstance(results, list) and "view" in renderer_context: for result in results: - url = self.get_result_url(result, context['view']) + url = self.get_result_url(result, context["view"]) if url is not None: - result.setdefault('url', url) + result.setdefault("url", url) - context['style'] = style - context['columns'] = columns - context['details'] = details - context['results'] = results - context['error_form'] = getattr(self, 'error_form', None) - context['error_title'] = getattr(self, 'error_title', None) + context["style"] = style + context["columns"] = columns + context["details"] = details + context["results"] = results + context["error_form"] = getattr(self, "error_form", None) + context["error_title"] = getattr(self, "error_title", None) return context def get_result_url(self, result, view): @@ -838,79 +851,82 @@ class AdminRenderer(BrowsableAPIRenderer): This only works with views that are generic-like (has `.lookup_field`) and viewset-like (has `.basename` / `.reverse_action()`). """ - if not hasattr(view, 'reverse_action') or \ - not hasattr(view, 'lookup_field'): + if not hasattr(view, "reverse_action") or not hasattr(view, "lookup_field"): return lookup_field = view.lookup_field - lookup_url_kwarg = getattr(view, 'lookup_url_kwarg', None) or lookup_field + lookup_url_kwarg = getattr(view, "lookup_url_kwarg", None) or lookup_field try: kwargs = {lookup_url_kwarg: result[lookup_field]} - return view.reverse_action('detail', kwargs=kwargs) + return view.reverse_action("detail", kwargs=kwargs) except (KeyError, NoReverseMatch): return class DocumentationRenderer(BaseRenderer): - media_type = 'text/html' - format = 'html' - charset = 'utf-8' - template = 'rest_framework/docs/index.html' - error_template = 'rest_framework/docs/error.html' - code_style = 'emacs' - languages = ['shell', 'javascript', 'python'] + media_type = "text/html" + format = "html" + charset = "utf-8" + template = "rest_framework/docs/index.html" + error_template = "rest_framework/docs/error.html" + code_style = "emacs" + languages = ["shell", "javascript", "python"] def get_context(self, data, request): return { - 'document': data, - 'langs': self.languages, - 'lang_htmls': ["rest_framework/docs/langs/%s.html" % l for l in self.languages], - 'lang_intro_htmls': ["rest_framework/docs/langs/%s-intro.html" % l for l in self.languages], - 'code_style': pygments_css(self.code_style), - 'request': request + "document": data, + "langs": self.languages, + "lang_htmls": [ + "rest_framework/docs/langs/%s.html" % l for l in self.languages + ], + "lang_intro_htmls": [ + "rest_framework/docs/langs/%s-intro.html" % l for l in self.languages + ], + "code_style": pygments_css(self.code_style), + "request": request, } def render(self, data, accepted_media_type=None, renderer_context=None): if isinstance(data, coreapi.Document): template = loader.get_template(self.template) - context = self.get_context(data, renderer_context['request']) - return template.render(context, request=renderer_context['request']) + context = self.get_context(data, renderer_context["request"]) + return template.render(context, request=renderer_context["request"]) else: template = loader.get_template(self.error_template) context = { "data": data, - "request": renderer_context['request'], - "response": renderer_context['response'], + "request": renderer_context["request"], + "response": renderer_context["response"], "debug": settings.DEBUG, } - return template.render(context, request=renderer_context['request']) + return template.render(context, request=renderer_context["request"]) class SchemaJSRenderer(BaseRenderer): - media_type = 'application/javascript' - format = 'javascript' - charset = 'utf-8' - template = 'rest_framework/schema.js' + media_type = "application/javascript" + format = "javascript" + charset = "utf-8" + template = "rest_framework/schema.js" def render(self, data, accepted_media_type=None, renderer_context=None): codec = coreapi.codecs.CoreJSONCodec() - schema = base64.b64encode(codec.encode(data)).decode('ascii') + schema = base64.b64encode(codec.encode(data)).decode("ascii") template = loader.get_template(self.template) - context = {'schema': mark_safe(schema)} - request = renderer_context['request'] + context = {"schema": mark_safe(schema)} + request = renderer_context["request"] return template.render(context, request=request) class MultiPartRenderer(BaseRenderer): - media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' - format = 'multipart' - charset = 'utf-8' - BOUNDARY = 'BoUnDaRyStRiNg' + media_type = "multipart/form-data; boundary=BoUnDaRyStRiNg" + format = "multipart" + charset = "utf-8" + BOUNDARY = "BoUnDaRyStRiNg" def render(self, data, accepted_media_type=None, renderer_context=None): - if hasattr(data, 'items'): + if hasattr(data, "items"): for key, value in data.items(): assert not isinstance(value, dict), ( "Test data contained a dictionary value for key '%s', " @@ -922,15 +938,15 @@ class MultiPartRenderer(BaseRenderer): class CoreJSONRenderer(BaseRenderer): - media_type = 'application/coreapi+json' + media_type = "application/coreapi+json" charset = None - format = 'corejson' + format = "corejson" def __init__(self): - assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not installed.' + assert coreapi, "Using CoreJSONRenderer, but `coreapi` is not installed." def render(self, data, media_type=None, renderer_context=None): - indent = bool(renderer_context.get('indent', 0)) + indent = bool(renderer_context.get("indent", 0)) codec = coreapi.codecs.CoreJSONCodec() return codec.dump(data, indent=indent) @@ -938,38 +954,35 @@ class CoreJSONRenderer(BaseRenderer): class _BaseOpenAPIRenderer: def get_schema(self, instance): CLASS_TO_TYPENAME = { - coreschema.Object: 'object', - coreschema.Array: 'array', - coreschema.Number: 'number', - coreschema.Integer: 'integer', - coreschema.String: 'string', - coreschema.Boolean: 'boolean', + coreschema.Object: "object", + coreschema.Array: "array", + coreschema.Number: "number", + coreschema.Integer: "integer", + coreschema.String: "string", + coreschema.Boolean: "boolean", } schema = {} if instance.__class__ in CLASS_TO_TYPENAME: - schema['type'] = CLASS_TO_TYPENAME[instance.__class__] - schema['title'] = instance.title - schema['description'] = instance.description - if hasattr(instance, 'enum'): - schema['enum'] = instance.enum + schema["type"] = CLASS_TO_TYPENAME[instance.__class__] + schema["title"] = instance.title + schema["description"] = instance.description + if hasattr(instance, "enum"): + schema["enum"] = instance.enum return schema def get_parameters(self, link): parameters = [] for field in link.fields: - if field.location not in ['path', 'query']: + if field.location not in ["path", "query"]: continue - parameter = { - 'name': field.name, - 'in': field.location, - } + parameter = {"name": field.name, "in": field.location} if field.required: - parameter['required'] = True + parameter["required"] = True if field.description: - parameter['description'] = field.description + parameter["description"] = field.description if field.schema: - parameter['schema'] = self.get_schema(field.schema) + parameter["schema"] = self.get_schema(field.schema) parameters.append(parameter) return parameters @@ -977,17 +990,15 @@ class _BaseOpenAPIRenderer: operation_id = "%s_%s" % (tag, name) if tag else name parameters = self.get_parameters(link) - operation = { - 'operationId': operation_id, - } + operation = {"operationId": operation_id} if link.title: - operation['summary'] = link.title + operation["summary"] = link.title if link.description: - operation['description'] = link.description + operation["description"] = link.description if parameters: - operation['parameters'] = parameters + operation["parameters"] = parameters if tag: - operation['tags'] = [tag] + operation["tags"] = [tag] return operation def get_paths(self, document): @@ -1011,41 +1022,39 @@ class _BaseOpenAPIRenderer: def get_structure(self, data): return { - 'openapi': '3.0.0', - 'info': { - 'version': '', - 'title': data.title, - 'description': data.description + "openapi": "3.0.0", + "info": { + "version": "", + "title": data.title, + "description": data.description, }, - 'servers': [{ - 'url': data.url - }], - 'paths': self.get_paths(data) + "servers": [{"url": data.url}], + "paths": self.get_paths(data), } class OpenAPIRenderer(_BaseOpenAPIRenderer): - media_type = 'application/vnd.oai.openapi' + media_type = "application/vnd.oai.openapi" charset = None - format = 'openapi' + format = "openapi" def __init__(self): - assert coreapi, 'Using OpenAPIRenderer, but `coreapi` is not installed.' - assert yaml, 'Using OpenAPIRenderer, but `pyyaml` is not installed.' + assert coreapi, "Using OpenAPIRenderer, but `coreapi` is not installed." + assert yaml, "Using OpenAPIRenderer, but `pyyaml` is not installed." def render(self, data, media_type=None, renderer_context=None): structure = self.get_structure(data) - return yaml.dump(structure, default_flow_style=False).encode('utf-8') + return yaml.dump(structure, default_flow_style=False).encode("utf-8") class JSONOpenAPIRenderer(_BaseOpenAPIRenderer): - media_type = 'application/vnd.oai.openapi+json' + media_type = "application/vnd.oai.openapi+json" charset = None - format = 'openapi-json' + format = "openapi-json" def __init__(self): - assert coreapi, 'Using JSONOpenAPIRenderer, but `coreapi` is not installed.' + assert coreapi, "Using JSONOpenAPIRenderer, but `coreapi` is not installed." def render(self, data, media_type=None, renderer_context=None): structure = self.get_structure(data) - return json.dumps(structure, indent=4).encode('utf-8') + return json.dumps(structure, indent=4).encode("utf-8") diff --git a/rest_framework/request.py b/rest_framework/request.py index a6d92e2bd..774d177c1 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -30,8 +30,10 @@ def is_form_media_type(media_type): Return True if the media type is a valid form media type. """ base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING)) - return (base_media_type == 'application/x-www-form-urlencoded' or - base_media_type == 'multipart/form-data') + return ( + base_media_type == "application/x-www-form-urlencoded" + or base_media_type == "multipart/form-data" + ) class override_method(object): @@ -49,12 +51,12 @@ class override_method(object): self.view = view self.request = request self.method = method - self.action = getattr(view, 'action', None) + self.action = getattr(view, "action", None) def __enter__(self): self.view.request = clone_request(self.request, self.method) # For viewsets we also set the `.action` attribute. - action_map = getattr(self.view, 'action_map', {}) + action_map = getattr(self.view, "action_map", {}) self.view.action = action_map.get(self.method.lower()) return self.view.request @@ -86,6 +88,7 @@ class Empty(object): Placeholder for unset attributes. Cannot use `None`, as that may be a valid value. """ + pass @@ -98,30 +101,32 @@ def clone_request(request, method): Internal helper method to clone a request, replacing with a different HTTP method. Used for checking permissions against other methods. """ - ret = Request(request=request._request, - parsers=request.parsers, - authenticators=request.authenticators, - negotiator=request.negotiator, - parser_context=request.parser_context) + ret = Request( + request=request._request, + parsers=request.parsers, + authenticators=request.authenticators, + negotiator=request.negotiator, + parser_context=request.parser_context, + ) ret._data = request._data ret._files = request._files ret._full_data = request._full_data ret._content_type = request._content_type ret._stream = request._stream ret.method = method - if hasattr(request, '_user'): + if hasattr(request, "_user"): ret._user = request._user - if hasattr(request, '_auth'): + if hasattr(request, "_auth"): ret._auth = request._auth - if hasattr(request, '_authenticator'): + if hasattr(request, "_authenticator"): ret._authenticator = request._authenticator - if hasattr(request, 'accepted_renderer'): + if hasattr(request, "accepted_renderer"): ret.accepted_renderer = request.accepted_renderer - if hasattr(request, 'accepted_media_type'): + if hasattr(request, "accepted_media_type"): ret.accepted_media_type = request.accepted_media_type - if hasattr(request, 'version'): + if hasattr(request, "version"): ret.version = request.version - if hasattr(request, 'versioning_scheme'): + if hasattr(request, "versioning_scheme"): ret.versioning_scheme = request.versioning_scheme return ret @@ -152,12 +157,19 @@ class Request(object): authenticating the request's user. """ - def __init__(self, request, parsers=None, authenticators=None, - negotiator=None, parser_context=None): + def __init__( + self, + request, + parsers=None, + authenticators=None, + negotiator=None, + parser_context=None, + ): assert isinstance(request, HttpRequest), ( - 'The `request` argument must be an instance of ' - '`django.http.HttpRequest`, not `{}.{}`.' - .format(request.__class__.__module__, request.__class__.__name__) + "The `request` argument must be an instance of " + "`django.http.HttpRequest`, not `{}.{}`.".format( + request.__class__.__module__, request.__class__.__name__ + ) ) self._request = request @@ -173,11 +185,11 @@ class Request(object): if self.parser_context is None: self.parser_context = {} - self.parser_context['request'] = self - self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET + self.parser_context["request"] = self + self.parser_context["encoding"] = request.encoding or settings.DEFAULT_CHARSET - force_user = getattr(request, '_force_auth_user', None) - force_token = getattr(request, '_force_auth_token', None) + force_user = getattr(request, "_force_auth_user", None) + force_token = getattr(request, "_force_auth_token", None) if force_user is not None or force_token is not None: forced_auth = ForcedAuthentication(force_user, force_token) self.authenticators = (forced_auth,) @@ -188,14 +200,14 @@ class Request(object): @property def content_type(self): meta = self._request.META - return meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', '')) + return meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) @property def stream(self): """ Returns an object that may be used to stream the request content. """ - if not _hasattr(self, '_stream'): + if not _hasattr(self, "_stream"): self._load_stream() return self._stream @@ -208,7 +220,7 @@ class Request(object): @property def data(self): - if not _hasattr(self, '_full_data'): + if not _hasattr(self, "_full_data"): self._load_data_and_files() return self._full_data @@ -218,7 +230,7 @@ class Request(object): Returns the user associated with the current request, as authenticated by the authentication classes provided to the request. """ - if not hasattr(self, '_user'): + if not hasattr(self, "_user"): with wrap_attributeerrors(): self._authenticate() return self._user @@ -242,7 +254,7 @@ class Request(object): Returns any non-user authentication information associated with the request, such as an authentication token. """ - if not hasattr(self, '_auth'): + if not hasattr(self, "_auth"): with wrap_attributeerrors(): self._authenticate() return self._auth @@ -262,7 +274,7 @@ class Request(object): Return the instance of the authentication instance class that was used to authenticate the request, or `None`. """ - if not hasattr(self, '_authenticator'): + if not hasattr(self, "_authenticator"): with wrap_attributeerrors(): self._authenticate() return self._authenticator @@ -271,7 +283,7 @@ class Request(object): """ Parses the request content into `self.data`. """ - if not _hasattr(self, '_data'): + if not _hasattr(self, "_data"): self._data, self._files = self._parse() if self._files: self._full_data = self._data.copy() @@ -292,7 +304,7 @@ class Request(object): meta = self._request.META try: content_length = int( - meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0)) + meta.get("CONTENT_LENGTH", meta.get("HTTP_CONTENT_LENGTH", 0)) ) except (ValueError, TypeError): content_length = 0 @@ -308,10 +320,7 @@ class Request(object): """ Return True if this requests supports parsing form data. """ - form_media = ( - 'application/x-www-form-urlencoded', - 'multipart/form-data' - ) + form_media = ("application/x-www-form-urlencoded", "multipart/form-data") return any([parser.media_type in form_media for parser in self.parsers]) def _parse(self): @@ -324,7 +333,7 @@ class Request(object): try: stream = self.stream except RawPostDataException: - if not hasattr(self._request, '_post'): + if not hasattr(self._request, "_post"): raise # If request.POST has been accessed in middleware, and a method='POST' # request was made with 'multipart/form-data', then the request stream @@ -335,7 +344,7 @@ class Request(object): if stream is None or media_type is None: if media_type and is_form_media_type(media_type): - empty_data = QueryDict('', encoding=self._request._encoding) + empty_data = QueryDict("", encoding=self._request._encoding) else: empty_data = {} empty_files = MultiValueDict() @@ -353,7 +362,7 @@ class Request(object): # re-raise. Ensures we don't simply repeat the error when # attempting to render the browsable renderer response, or when # logging the request or similar. - self._data = QueryDict('', encoding=self._request._encoding) + self._data = QueryDict("", encoding=self._request._encoding) self._files = MultiValueDict() self._full_data = self._data raise @@ -416,33 +425,33 @@ class Request(object): @property def DATA(self): raise NotImplementedError( - '`request.DATA` has been deprecated in favor of `request.data` ' - 'since version 3.0, and has been fully removed as of version 3.2.' + "`request.DATA` has been deprecated in favor of `request.data` " + "since version 3.0, and has been fully removed as of version 3.2." ) @property def POST(self): # Ensure that request.POST uses our request parsing. - if not _hasattr(self, '_data'): + if not _hasattr(self, "_data"): self._load_data_and_files() if is_form_media_type(self.content_type): return self._data - return QueryDict('', encoding=self._request._encoding) + return QueryDict("", encoding=self._request._encoding) @property def FILES(self): # Leave this one alone for backwards compat with Django's request.FILES # Different from the other two cases, which are not valid property # names on the WSGIRequest class. - if not _hasattr(self, '_files'): + if not _hasattr(self, "_files"): self._load_data_and_files() return self._files @property def QUERY_PARAMS(self): raise NotImplementedError( - '`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` ' - 'since version 3.0, and has been fully removed as of version 3.2.' + "`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` " + "since version 3.0, and has been fully removed as of version 3.2." ) def force_plaintext_errors(self, value): diff --git a/rest_framework/response.py b/rest_framework/response.py index bf0663255..3d476e4e1 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -19,9 +19,15 @@ class Response(SimpleTemplateResponse): arbitrary media types. """ - def __init__(self, data=None, status=None, - template_name=None, headers=None, - exception=False, content_type=None): + def __init__( + self, + data=None, + status=None, + template_name=None, + headers=None, + exception=False, + content_type=None, + ): """ Alters the init arguments slightly. For example, drop 'template_name', and instead use 'data'. @@ -33,9 +39,9 @@ class Response(SimpleTemplateResponse): if isinstance(data, Serializer): msg = ( - 'You passed a Serializer instance as data, but ' - 'probably meant to pass serialized `.data` or ' - '`.error`. representation.' + "You passed a Serializer instance as data, but " + "probably meant to pass serialized `.data` or " + "`.error`. representation." ) raise AssertionError(msg) @@ -50,14 +56,14 @@ class Response(SimpleTemplateResponse): @property def rendered_content(self): - renderer = getattr(self, 'accepted_renderer', None) - accepted_media_type = getattr(self, 'accepted_media_type', None) - context = getattr(self, 'renderer_context', None) + renderer = getattr(self, "accepted_renderer", None) + accepted_media_type = getattr(self, "accepted_media_type", None) + context = getattr(self, "renderer_context", None) assert renderer, ".accepted_renderer not set on Response" assert accepted_media_type, ".accepted_media_type not set on Response" assert context is not None, ".renderer_context not set on Response" - context['response'] = self + context["response"] = self media_type = renderer.media_type charset = renderer.charset @@ -67,18 +73,17 @@ class Response(SimpleTemplateResponse): content_type = "{0}; charset={1}".format(media_type, charset) elif content_type is None: content_type = media_type - self['Content-Type'] = content_type + self["Content-Type"] = content_type ret = renderer.render(self.data, accepted_media_type, context) if isinstance(ret, six.text_type): assert charset, ( - 'renderer returned unicode, and did not specify ' - 'a charset value.' + "renderer returned unicode, and did not specify " "a charset value." ) return bytes(ret.encode(charset)) if not ret: - del self['Content-Type'] + del self["Content-Type"] return ret @@ -88,7 +93,7 @@ class Response(SimpleTemplateResponse): Returns reason text corresponding to our HTTP response status code. Provided for convenience. """ - return responses.get(self.status_code, '') + return responses.get(self.status_code, "") def __getstate__(self): """ @@ -96,10 +101,15 @@ class Response(SimpleTemplateResponse): """ state = super(Response, self).__getstate__() for key in ( - 'accepted_renderer', 'renderer_context', 'resolver_match', - 'client', 'request', 'json', 'wsgi_request' + "accepted_renderer", + "renderer_context", + "resolver_match", + "client", + "request", + "json", + "wsgi_request", ): if key in state: del state[key] - state['_closable_objects'] = [] + state["_closable_objects"] = [] return state diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index e9cf737f1..38d846b16 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -3,8 +3,7 @@ Provide urlresolver functions that return fully qualified URLs or view names """ from __future__ import unicode_literals -from django.urls import NoReverseMatch -from django.urls import reverse as django_reverse +from django.urls import NoReverseMatch, reverse as django_reverse from django.utils import six from django.utils.functional import lazy @@ -20,9 +19,7 @@ def preserve_builtin_query_params(url, request=None): if request is None: return url - overrides = [ - api_settings.URL_FORMAT_OVERRIDE, - ] + overrides = [api_settings.URL_FORMAT_OVERRIDE] for param in overrides: if param and (param in request.GET): @@ -38,7 +35,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra to the versioning scheme instance, so that the resulting URL can be modified if needed. """ - scheme = getattr(request, 'versioning_scheme', None) + scheme = getattr(request, "versioning_scheme", None) if scheme is not None: try: url = scheme.reverse(viewname, args, kwargs, request, format, **extra) @@ -59,7 +56,7 @@ def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extr """ if format is not None: kwargs = kwargs or {} - kwargs['format'] = format + kwargs["format"] = format url = django_reverse(viewname, args=args, kwargs=kwargs, **extra) if request: return request.build_absolute_uri(url) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 1cacea181..6296b0f9c 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -25,9 +25,7 @@ from django.urls import NoReverseMatch from django.utils import six from django.utils.deprecation import RenameMethodsBase -from rest_framework import ( - RemovedInDRF310Warning, RemovedInDRF311Warning, views -) +from rest_framework import RemovedInDRF310Warning, RemovedInDRF311Warning, views from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator @@ -35,8 +33,9 @@ from rest_framework.schemas.views import SchemaView from rest_framework.settings import api_settings from rest_framework.urlpatterns import format_suffix_patterns -Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs']) -DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs']) + +Route = namedtuple("Route", ["url", "mapping", "name", "detail", "initkwargs"]) +DynamicRoute = namedtuple("DynamicRoute", ["url", "name", "detail", "initkwargs"]) class DynamicDetailRoute(object): @@ -45,7 +44,8 @@ class DynamicDetailRoute(object): "`DynamicDetailRoute` is deprecated and will be removed in 3.10 " "in favor of `DynamicRoute`, which accepts a `detail` boolean. Use " "`DynamicRoute(url, name, True, initkwargs)` instead.", - RemovedInDRF310Warning, stacklevel=2 + RemovedInDRF310Warning, + stacklevel=2, ) return DynamicRoute(url, name, True, initkwargs) @@ -56,7 +56,8 @@ class DynamicListRoute(object): "`DynamicListRoute` is deprecated and will be removed in 3.10 in " "favor of `DynamicRoute`, which accepts a `detail` boolean. Use " "`DynamicRoute(url, name, False, initkwargs)` instead.", - RemovedInDRF310Warning, stacklevel=2 + RemovedInDRF310Warning, + stacklevel=2, ) return DynamicRoute(url, name, False, initkwargs) @@ -65,8 +66,8 @@ def escape_curly_brackets(url_path): """ Double brackets in regex of url_path for escape string formatting """ - if ('{' and '}') in url_path: - url_path = url_path.replace('{', '{{').replace('}', '}}') + if ("{" and "}") in url_path: + url_path = url_path.replace("{", "{{").replace("}", "}}") return url_path @@ -79,7 +80,7 @@ def flatten(list_of_lists): class RenameRouterMethods(RenameMethodsBase): renamed_methods = ( - ('get_default_base_name', 'get_default_basename', RemovedInDRF311Warning), + ("get_default_base_name", "get_default_basename", RemovedInDRF311Warning), ) @@ -92,8 +93,9 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)): msg = "The `base_name` argument is pending deprecation in favor of `basename`." warnings.warn(msg, RemovedInDRF311Warning, 2) - assert not (basename and base_name), ( - "Do not provide both the `basename` and `base_name` arguments.") + assert not ( + basename and base_name + ), "Do not provide both the `basename` and `base_name` arguments." if basename is None: basename = base_name @@ -103,7 +105,7 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)): self.registry.append((prefix, viewset, basename)) # invalidate the urls cache - if hasattr(self, '_urls'): + if hasattr(self, "_urls"): del self._urls def get_default_basename(self, viewset): @@ -111,17 +113,17 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)): If `basename` is not specified, attempt to automatically determine it from the viewset. """ - raise NotImplementedError('get_default_basename must be overridden') + raise NotImplementedError("get_default_basename must be overridden") def get_urls(self): """ Return a list of URL patterns, given the registered viewsets. """ - raise NotImplementedError('get_urls must be overridden') + raise NotImplementedError("get_urls must be overridden") @property def urls(self): - if not hasattr(self, '_urls'): + if not hasattr(self, "_urls"): self._urls = self.get_urls() return self._urls @@ -131,48 +133,45 @@ class SimpleRouter(BaseRouter): routes = [ # List route. Route( - url=r'^{prefix}{trailing_slash}$', - mapping={ - 'get': 'list', - 'post': 'create' - }, - name='{basename}-list', + url=r"^{prefix}{trailing_slash}$", + mapping={"get": "list", "post": "create"}, + name="{basename}-list", detail=False, - initkwargs={'suffix': 'List'} + initkwargs={"suffix": "List"}, ), # Dynamically generated list routes. Generated using # @action(detail=False) decorator on methods of the viewset. DynamicRoute( - url=r'^{prefix}/{url_path}{trailing_slash}$', - name='{basename}-{url_name}', + url=r"^{prefix}/{url_path}{trailing_slash}$", + name="{basename}-{url_name}", detail=False, - initkwargs={} + initkwargs={}, ), # Detail route. Route( - url=r'^{prefix}/{lookup}{trailing_slash}$', + url=r"^{prefix}/{lookup}{trailing_slash}$", mapping={ - 'get': 'retrieve', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy' + "get": "retrieve", + "put": "update", + "patch": "partial_update", + "delete": "destroy", }, - name='{basename}-detail', + name="{basename}-detail", detail=True, - initkwargs={'suffix': 'Instance'} + initkwargs={"suffix": "Instance"}, ), # Dynamically generated detail routes. Generated using # @action(detail=True) decorator on methods of the viewset. DynamicRoute( - url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$', - name='{basename}-{url_name}', + url=r"^{prefix}/{lookup}/{url_path}{trailing_slash}$", + name="{basename}-{url_name}", detail=True, - initkwargs={} + initkwargs={}, ), ] def __init__(self, trailing_slash=True): - self.trailing_slash = '/' if trailing_slash else '' + self.trailing_slash = "/" if trailing_slash else "" super(SimpleRouter, self).__init__() def get_default_basename(self, viewset): @@ -180,11 +179,13 @@ class SimpleRouter(BaseRouter): If `basename` is not specified, attempt to automatically determine it from the viewset. """ - queryset = getattr(viewset, 'queryset', None) + queryset = getattr(viewset, "queryset", None) - assert queryset is not None, '`basename` argument not specified, and could ' \ - 'not automatically determine the name from the viewset, as ' \ - 'it does not have a `.queryset` attribute.' + assert queryset is not None, ( + "`basename` argument not specified, and could " + "not automatically determine the name from the viewset, as " + "it does not have a `.queryset` attribute." + ) return queryset.model._meta.object_name.lower() @@ -196,18 +197,29 @@ class SimpleRouter(BaseRouter): """ # converting to list as iterables are good for one pass, known host needs to be checked again and again for # different functions. - known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)])) + known_actions = list( + flatten( + [ + route.mapping.values() + for route in self.routes + if isinstance(route, Route) + ] + ) + ) extra_actions = viewset.get_extra_actions() # checking action names against the known actions list not_allowed = [ - action.__name__ for action in extra_actions + action.__name__ + for action in extra_actions if action.__name__ in known_actions ] if not_allowed: - msg = ('Cannot use the @action decorator on the following ' - 'methods, as they are existing routes: %s') - raise ImproperlyConfigured(msg % ', '.join(not_allowed)) + msg = ( + "Cannot use the @action decorator on the following " + "methods, as they are existing routes: %s" + ) + raise ImproperlyConfigured(msg % ", ".join(not_allowed)) # partition detail and list actions detail_actions = [action for action in extra_actions if action.detail] @@ -216,9 +228,13 @@ class SimpleRouter(BaseRouter): routes = [] for route in self.routes: if isinstance(route, DynamicRoute) and route.detail: - routes += [self._get_dynamic_route(route, action) for action in detail_actions] + routes += [ + self._get_dynamic_route(route, action) for action in detail_actions + ] elif isinstance(route, DynamicRoute) and not route.detail: - routes += [self._get_dynamic_route(route, action) for action in list_actions] + routes += [ + self._get_dynamic_route(route, action) for action in list_actions + ] else: routes.append(route) @@ -231,9 +247,9 @@ class SimpleRouter(BaseRouter): url_path = escape_curly_brackets(action.url_path) return Route( - url=route.url.replace('{url_path}', url_path), + url=route.url.replace("{url_path}", url_path), mapping=action.mapping, - name=route.name.replace('{url_name}', action.url_name), + name=route.name.replace("{url_name}", action.url_name), detail=route.detail, initkwargs=initkwargs, ) @@ -250,7 +266,7 @@ class SimpleRouter(BaseRouter): bound_methods[method] = action return bound_methods - def get_lookup_regex(self, viewset, lookup_prefix=''): + def get_lookup_regex(self, viewset, lookup_prefix=""): """ Given a viewset, return the portion of URL regex that is used to match against a single instance. @@ -261,16 +277,16 @@ class SimpleRouter(BaseRouter): https://github.com/alanjds/drf-nested-routers """ - base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})' + base_regex = "(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})" # Use `pk` as default field, unset set. Default regex should not # consume `.json` style suffixes and should break at '/' boundaries. - lookup_field = getattr(viewset, 'lookup_field', 'pk') - lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field - lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+') + lookup_field = getattr(viewset, "lookup_field", "pk") + lookup_url_kwarg = getattr(viewset, "lookup_url_kwarg", None) or lookup_field + lookup_value = getattr(viewset, "lookup_value_regex", "[^/.]+") return base_regex.format( lookup_prefix=lookup_prefix, lookup_url_kwarg=lookup_url_kwarg, - lookup_value=lookup_value + lookup_value=lookup_value, ) def get_urls(self): @@ -292,23 +308,18 @@ class SimpleRouter(BaseRouter): # Build the url pattern regex = route.url.format( - prefix=prefix, - lookup=lookup, - trailing_slash=self.trailing_slash + prefix=prefix, lookup=lookup, trailing_slash=self.trailing_slash ) # If there is no prefix, the first part of the url is probably # controlled by project's urls.py and the router is in an app, # so a slash in the beginning will (A) cause Django to give # warnings and (B) generate URLS that will require using '//'. - if not prefix and regex[:2] == '^/': - regex = '^' + regex[2:] + if not prefix and regex[:2] == "^/": + regex = "^" + regex[2:] initkwargs = route.initkwargs.copy() - initkwargs.update({ - 'basename': basename, - 'detail': route.detail, - }) + initkwargs.update({"basename": basename, "detail": route.detail}) view = viewset.as_view(mapping, **initkwargs) name = route.name.format(basename=basename) @@ -321,6 +332,7 @@ class APIRootView(views.APIView): """ The default basic root view for DefaultRouter """ + _ignore_model_permissions = True schema = None # exclude from schema api_root_dict = None @@ -331,14 +343,14 @@ class APIRootView(views.APIView): namespace = request.resolver_match.namespace for key, url_name in self.api_root_dict.items(): if namespace: - url_name = namespace + ':' + url_name + url_name = namespace + ":" + url_name try: ret[key] = reverse( url_name, args=args, kwargs=kwargs, request=request, - format=kwargs.get('format', None) + format=kwargs.get("format", None), ) except NoReverseMatch: # Don't bail out if eg. no list routes exist, only detail routes. @@ -352,17 +364,18 @@ class DefaultRouter(SimpleRouter): The default router extends the SimpleRouter, but also adds in a default API root view, and adds format suffix patterns to the URLs. """ + include_root_view = True include_format_suffixes = True - root_view_name = 'api-root' + root_view_name = "api-root" default_schema_renderers = None APIRootView = APIRootView APISchemaView = SchemaView SchemaGenerator = SchemaGenerator def __init__(self, *args, **kwargs): - if 'root_renderers' in kwargs: - self.root_renderers = kwargs.pop('root_renderers') + if "root_renderers" in kwargs: + self.root_renderers = kwargs.pop("root_renderers") else: self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) super(DefaultRouter, self).__init__(*args, **kwargs) @@ -387,7 +400,7 @@ class DefaultRouter(SimpleRouter): if self.include_root_view: view = self.get_api_root_view(api_urls=urls) - root_url = url(r'^$', view, name=self.root_view_name) + root_url = url(r"^$", view, name=self.root_view_name) urls.append(root_url) if self.include_format_suffixes: diff --git a/rest_framework/schemas/__init__.py b/rest_framework/schemas/__init__.py index ba0ec6536..00acae91c 100644 --- a/rest_framework/schemas/__init__.py +++ b/rest_framework/schemas/__init__.py @@ -27,18 +27,29 @@ from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa def get_schema_view( - title=None, url=None, description=None, urlconf=None, renderer_classes=None, - public=False, patterns=None, generator_class=SchemaGenerator, - authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, - permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): + title=None, + url=None, + description=None, + urlconf=None, + renderer_classes=None, + public=False, + patterns=None, + generator_class=SchemaGenerator, + authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, + permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES, +): """ Return a schema view. """ # Avoid import cycle on APIView from .views import SchemaView + generator = generator_class( - title=title, url=url, description=description, - urlconf=urlconf, patterns=patterns, + title=title, + url=url, + description=description, + urlconf=urlconf, + patterns=patterns, ) return SchemaView.as_view( renderer_classes=renderer_classes, diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index db226a6c1..e1b05273b 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -15,7 +15,11 @@ from django.utils import six from rest_framework import exceptions from rest_framework.compat import ( - URLPattern, URLResolver, coreapi, coreschema, get_original_route + URLPattern, + URLResolver, + coreapi, + coreschema, + get_original_route, ) from rest_framework.request import clone_request from rest_framework.settings import api_settings @@ -25,7 +29,7 @@ from .utils import is_list_view def common_path(paths): - split_paths = [path.strip('/').split('/') for path in paths] + split_paths = [path.strip("/").split("/") for path in paths] s1 = min(split_paths) s2 = max(split_paths) common = s1 @@ -33,7 +37,7 @@ def common_path(paths): if c != s2[i]: common = s1[:i] break - return '/' + '/'.join(common) + return "/" + "/".join(common) def get_pk_name(model): @@ -47,7 +51,8 @@ def is_api_view(callback): """ # Avoid import cycle on APIView from rest_framework.views import APIView - cls = getattr(callback, 'cls', None) + + cls = getattr(callback, "cls", None) return (cls is not None) and issubclass(cls, APIView) @@ -78,7 +83,7 @@ class LinkNode(OrderedDict): current_val = self.methods_counter[preferred_key] self.methods_counter[preferred_key] += 1 - key = '{}_{}'.format(preferred_key, current_val) + key = "{}_{}".format(preferred_key, current_val) if key not in self: return key @@ -101,9 +106,7 @@ def insert_into(target, keys, value): target.links.append((keys[-1], value)) except TypeError: msg = INSERT_INTO_COLLISION_FMT.format( - value_url=value.url, - target_url=target.url, - keys=keys + value_url=value.url, target_url=target.url, keys=keys ) raise ValueError(msg) @@ -119,24 +122,25 @@ def distribute_links(obj): def is_custom_action(action): return action not in { - 'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy' + "retrieve", + "list", + "create", + "update", + "partial_update", + "destroy", } def endpoint_ordering(endpoint): path, method, callback = endpoint - method_priority = { - 'GET': 0, - 'POST': 1, - 'PUT': 2, - 'PATCH': 3, - 'DELETE': 4 - }.get(method, 5) + method_priority = {"GET": 0, "POST": 1, "PUT": 2, "PATCH": 3, "DELETE": 4}.get( + method, 5 + ) return (path, method_priority) _PATH_PARAMETER_COMPONENT_RE = re.compile( - r'<(?:(?P[^>:]+):)?(?P\w+)>' + r"<(?:(?P[^>:]+):)?(?P\w+)>" ) @@ -144,6 +148,7 @@ class EndpointEnumerator(object): """ A class to determine the available API endpoints that a project exposes. """ + def __init__(self, patterns=None, urlconf=None): if patterns is None: if urlconf is None: @@ -159,7 +164,7 @@ class EndpointEnumerator(object): self.patterns = patterns - def get_api_endpoints(self, patterns=None, prefix=''): + def get_api_endpoints(self, patterns=None, prefix=""): """ Return a list of all available API endpoints by inspecting the URL conf. """ @@ -180,8 +185,7 @@ class EndpointEnumerator(object): elif isinstance(pattern, URLResolver): nested_endpoints = self.get_api_endpoints( - patterns=pattern.url_patterns, - prefix=path_regex + patterns=pattern.url_patterns, prefix=path_regex ) api_endpoints.extend(nested_endpoints) @@ -196,7 +200,7 @@ class EndpointEnumerator(object): path = simplify_regex(path_regex) # Strip Django 2.0 convertors as they are incompatible with uritemplate format - path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g}', path) + path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r"{\g}", path) return path def should_include_endpoint(self, path, callback): @@ -209,11 +213,11 @@ class EndpointEnumerator(object): if callback.cls.schema is None: return False - if 'schema' in callback.initkwargs: - if callback.initkwargs['schema'] is None: + if "schema" in callback.initkwargs: + if callback.initkwargs["schema"] is None: return False - if path.endswith('.{format}') or path.endswith('.{format}/'): + if path.endswith(".{format}") or path.endswith(".{format}/"): return False # Ignore .json style URLs. return True @@ -222,24 +226,24 @@ class EndpointEnumerator(object): """ Return a list of the valid HTTP methods for this endpoint. """ - if hasattr(callback, 'actions'): + if hasattr(callback, "actions"): actions = set(callback.actions) http_method_names = set(callback.cls.http_method_names) methods = [method.upper() for method in actions & http_method_names] else: methods = callback.cls().allowed_methods - return [method for method in methods if method not in ('OPTIONS', 'HEAD')] + return [method for method in methods if method not in ("OPTIONS", "HEAD")] class SchemaGenerator(object): # Map HTTP methods onto actions. default_mapping = { - 'get': 'retrieve', - 'post': 'create', - 'put': 'update', - 'patch': 'partial_update', - 'delete': 'destroy', + "get": "retrieve", + "post": "create", + "put": "update", + "patch": "partial_update", + "delete": "destroy", } endpoint_inspector_cls = EndpointEnumerator @@ -253,12 +257,14 @@ class SchemaGenerator(object): # Set by 'SCHEMA_COERCE_PATH_PK'. coerce_path_pk = None - def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None): - assert coreapi, '`coreapi` must be installed for schema support.' - assert coreschema, '`coreschema` must be installed for schema support.' + def __init__( + self, title=None, url=None, description=None, patterns=None, urlconf=None + ): + assert coreapi, "`coreapi` must be installed for schema support." + assert coreschema, "`coreschema` must be installed for schema support." - if url and not url.endswith('/'): - url += '/' + if url and not url.endswith("/"): + url += "/" self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK @@ -288,8 +294,7 @@ class SchemaGenerator(object): distribute_links(links) return coreapi.Document( - title=self.title, description=self.description, - url=url, content=links + title=self.title, description=self.description, url=url, content=links ) def get_links(self, request=None): @@ -317,7 +322,7 @@ class SchemaGenerator(object): if not self.has_view_permissions(path, method, view): continue link = view.schema.get_link(path, method, base_url=self.url) - subpath = path[len(prefix):] + subpath = path[len(prefix) :] keys = self.get_keys(subpath, method, view) insert_into(links, keys, link) @@ -342,35 +347,35 @@ class SchemaGenerator(object): """ prefixes = [] for path in paths: - components = path.strip('/').split('/') + components = path.strip("/").split("/") initial_components = [] for component in components: - if '{' in component: + if "{" in component: break initial_components.append(component) - prefix = '/'.join(initial_components[:-1]) + prefix = "/".join(initial_components[:-1]) if not prefix: # We can just break early in the case that there's at least # one URL that doesn't have a path prefix. - return '/' - prefixes.append('/' + prefix + '/') + return "/" + prefixes.append("/" + prefix + "/") return common_path(prefixes) def create_view(self, callback, method, request=None): """ Given a callback, return an actual view instance. """ - view = callback.cls(**getattr(callback, 'initkwargs', {})) + view = callback.cls(**getattr(callback, "initkwargs", {})) view.args = () view.kwargs = {} view.format_kwarg = None view.request = None - view.action_map = getattr(callback, 'actions', None) + view.action_map = getattr(callback, "actions", None) - actions = getattr(callback, 'actions', None) + actions = getattr(callback, "actions", None) if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' + if method == "OPTIONS": + view.action = "metadata" else: view.action = actions.get(method.lower()) @@ -398,14 +403,14 @@ class SchemaGenerator(object): where possible. This is cleaner for an external representation. (Ie. "this is an identifier", not "this is a database primary key") """ - if not self.coerce_path_pk or '{pk}' not in path: + if not self.coerce_path_pk or "{pk}" not in path: return path - model = getattr(getattr(view, 'queryset', None), 'model', None) + model = getattr(getattr(view, "queryset", None), "model", None) if model: field_name = get_pk_name(model) else: - field_name = 'id' - return path.replace('{pk}', '{%s}' % field_name) + field_name = "id" + return path.replace("{pk}", "{%s}" % field_name) # Method for generating the link layout.... @@ -421,20 +426,20 @@ class SchemaGenerator(object): /users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") /users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") """ - if hasattr(view, 'action'): + if hasattr(view, "action"): # Viewsets have explicitly named actions. action = view.action else: # Views have no associated action, so we determine one from the method. if is_list_view(subpath, method, view): - action = 'list' + action = "list" else: action = self.default_mapping[method.lower()] named_path_components = [ - component for component - in subpath.strip('/').split('/') - if '{' not in component + component + for component in subpath.strip("/").split("/") + if "{" not in component ] if is_custom_action(action): diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 85142edce..82e75d945 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -21,46 +21,38 @@ from rest_framework.utils import formatting from .utils import is_list_view -header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') + +header_regex = re.compile("^[a-zA-Z][0-9A-Za-z_]*:") def field_to_schema(field): - title = force_text(field.label) if field.label else '' - description = force_text(field.help_text) if field.help_text else '' + title = force_text(field.label) if field.label else "" + description = force_text(field.help_text) if field.help_text else "" if isinstance(field, (serializers.ListSerializer, serializers.ListField)): child_schema = field_to_schema(field.child) return coreschema.Array( - items=child_schema, - title=title, - description=description + items=child_schema, title=title, description=description ) elif isinstance(field, serializers.DictField): - return coreschema.Object( - title=title, - description=description - ) + return coreschema.Object(title=title, description=description) elif isinstance(field, serializers.Serializer): return coreschema.Object( - properties=OrderedDict([ - (key, field_to_schema(value)) - for key, value - in field.fields.items() - ]), + properties=OrderedDict( + [(key, field_to_schema(value)) for key, value in field.fields.items()] + ), title=title, - description=description + description=description, ) elif isinstance(field, serializers.ManyRelatedField): related_field_schema = field_to_schema(field.child_relation) return coreschema.Array( - items=related_field_schema, - title=title, - description=description + items=related_field_schema, title=title, description=description ) elif isinstance(field, serializers.PrimaryKeyRelatedField): schema_cls = coreschema.String - model = getattr(field.queryset, 'model', None) + model = getattr(field.queryset, "model", None) if model is not None: model_field = model._meta.pk if isinstance(model_field, models.AutoField): @@ -72,13 +64,11 @@ def field_to_schema(field): return coreschema.Array( items=coreschema.Enum(enum=list(field.choices)), title=title, - description=description + description=description, ) elif isinstance(field, serializers.ChoiceField): return coreschema.Enum( - enum=list(field.choices), - title=title, - description=description + enum=list(field.choices), title=title, description=description ) elif isinstance(field, serializers.BooleanField): return coreschema.Boolean(title=title, description=description) @@ -87,25 +77,17 @@ def field_to_schema(field): elif isinstance(field, serializers.IntegerField): return coreschema.Integer(title=title, description=description) elif isinstance(field, serializers.DateField): - return coreschema.String( - title=title, - description=description, - format='date' - ) + return coreschema.String(title=title, description=description, format="date") elif isinstance(field, serializers.DateTimeField): return coreschema.String( - title=title, - description=description, - format='date-time' + title=title, description=description, format="date-time" ) elif isinstance(field, serializers.JSONField): return coreschema.Object(title=title, description=description) - if field.style.get('base_template') == 'textarea.html': + if field.style.get("base_template") == "textarea.html": return coreschema.String( - title=title, - description=description, - format='textarea' + title=title, description=description, format="textarea" ) return coreschema.String(title=title, description=description) @@ -113,15 +95,14 @@ def field_to_schema(field): def get_pk_description(model, model_field): if isinstance(model_field, models.AutoField): - value_type = _('unique integer value') + value_type = _("unique integer value") elif isinstance(model_field, models.UUIDField): - value_type = _('UUID string') + value_type = _("UUID string") else: - value_type = _('unique value') + value_type = _("unique value") - return _('A {value_type} identifying this {name}.').format( - value_type=value_type, - name=model._meta.verbose_name, + return _("A {value_type} identifying this {name}.").format( + value_type=value_type, name=model._meta.verbose_name ) @@ -200,6 +181,7 @@ class AutoSchema(ViewInspector): Responsible for per-view introspection and schema generation. """ + def __init__(self, manual_fields=None): """ Parameters: @@ -221,14 +203,14 @@ class AutoSchema(ViewInspector): manual_fields = self.get_manual_fields(path, method) fields = self.update_fields(fields, manual_fields) - if fields and any([field.location in ('form', 'body') for field in fields]): + if fields and any([field.location in ("form", "body") for field in fields]): encoding = self.get_encoding(path, method) else: encoding = None description = self.get_description(path, method) - if base_url and path.startswith('/'): + if base_url and path.startswith("/"): path = path[1:] return coreapi.Link( @@ -236,7 +218,7 @@ class AutoSchema(ViewInspector): action=method.lower(), encoding=encoding, fields=fields, - description=description + description=description, ) def get_description(self, path, method): @@ -248,25 +230,31 @@ class AutoSchema(ViewInspector): """ view = self.view - method_name = getattr(view, 'action', method.lower()) + method_name = getattr(view, "action", method.lower()) method_docstring = getattr(view, method_name, None).__doc__ if method_docstring: # An explicit docstring on the method or action. - return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring))) + return self._get_description_section( + view, method.lower(), formatting.dedent(smart_text(method_docstring)) + ) else: - return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description()) + return self._get_description_section( + view, + getattr(view, "action", method.lower()), + view.get_view_description(), + ) def _get_description_section(self, view, header, description): lines = [line for line in description.splitlines()] - current_section = '' - sections = {'': ''} + current_section = "" + sections = {"": ""} for line in lines: if header_regex.match(line): - current_section, seperator, lead = line.partition(':') + current_section, seperator, lead = line.partition(":") sections[current_section] = lead.strip() else: - sections[current_section] += '\n' + line + sections[current_section] += "\n" + line # TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys` coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES @@ -275,7 +263,7 @@ class AutoSchema(ViewInspector): if header in coerce_method_names: if coerce_method_names[header] in sections: return sections[coerce_method_names[header]].strip() - return sections[''].strip() + return sections[""].strip() def get_path_fields(self, path, method): """ @@ -283,12 +271,12 @@ class AutoSchema(ViewInspector): templated path variables. """ view = self.view - model = getattr(getattr(view, 'queryset', None), 'model', None) + model = getattr(getattr(view, "queryset", None), "model", None) fields = [] for variable in uritemplate.variables(path): - title = '' - description = '' + title = "" + description = "" schema_cls = coreschema.String kwargs = {} if model is not None: @@ -306,16 +294,19 @@ class AutoSchema(ViewInspector): elif model_field is not None and model_field.primary_key: description = get_pk_description(model, model_field) - if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: - kwargs['pattern'] = view.lookup_value_regex + if ( + hasattr(view, "lookup_value_regex") + and view.lookup_field == variable + ): + kwargs["pattern"] = view.lookup_value_regex elif isinstance(model_field, models.AutoField): schema_cls = coreschema.Integer field = coreapi.Field( name=variable, - location='path', + location="path", required=True, - schema=schema_cls(title=title, description=description, **kwargs) + schema=schema_cls(title=title, description=description, **kwargs), ) fields.append(field) @@ -328,28 +319,29 @@ class AutoSchema(ViewInspector): """ view = self.view - if method not in ('PUT', 'PATCH', 'POST'): + if method not in ("PUT", "PATCH", "POST"): return [] - if not hasattr(view, 'get_serializer'): + if not hasattr(view, "get_serializer"): return [] try: serializer = view.get_serializer() except exceptions.APIException: serializer = None - warnings.warn('{}.get_serializer() raised an exception during ' - 'schema generation. Serializer fields will not be ' - 'generated for {} {}.' - .format(view.__class__.__name__, method, path)) + warnings.warn( + "{}.get_serializer() raised an exception during " + "schema generation. Serializer fields will not be " + "generated for {} {}.".format(view.__class__.__name__, method, path) + ) if isinstance(serializer, serializers.ListSerializer): return [ coreapi.Field( - name='data', - location='body', + name="data", + location="body", required=True, - schema=coreschema.Array() + schema=coreschema.Array(), ) ] @@ -361,12 +353,12 @@ class AutoSchema(ViewInspector): if field.read_only or isinstance(field, serializers.HiddenField): continue - required = field.required and method != 'PATCH' + required = field.required and method != "PATCH" field = coreapi.Field( name=field.field_name, - location='form', + location="form", required=required, - schema=field_to_schema(field) + schema=field_to_schema(field), ) fields.append(field) @@ -378,7 +370,7 @@ class AutoSchema(ViewInspector): if not is_list_view(path, method, view): return [] - pagination = getattr(view, 'pagination_class', None) + pagination = getattr(view, "pagination_class", None) if not pagination: return [] @@ -397,11 +389,17 @@ class AutoSchema(ViewInspector): Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore) to allow changes based on user experience. """ - if getattr(self.view, 'filter_backends', None) is None: + if getattr(self.view, "filter_backends", None) is None: return False - if hasattr(self.view, 'action'): - return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"] + if hasattr(self.view, "action"): + return self.view.action in [ + "list", + "retrieve", + "update", + "partial_update", + "destroy", + ] return method.lower() in ["get", "put", "patch", "delete"] @@ -447,18 +445,18 @@ class AutoSchema(ViewInspector): # Core API supports the following request encodings over HTTP... supported_media_types = { - 'application/json', - 'application/x-www-form-urlencoded', - 'multipart/form-data', + "application/json", + "application/x-www-form-urlencoded", + "multipart/form-data", } - parser_classes = getattr(view, 'parser_classes', []) + parser_classes = getattr(view, "parser_classes", []) for parser_class in parser_classes: - media_type = getattr(parser_class, 'media_type', None) + media_type = getattr(parser_class, "media_type", None) if media_type in supported_media_types: return media_type # Raw binary uploads are supported with "application/octet-stream" - if media_type == '*/*': - return 'application/octet-stream' + if media_type == "*/*": + return "application/octet-stream" return None @@ -468,7 +466,8 @@ class ManualSchema(ViewInspector): Allows providing a list of coreapi.Fields, plus an optional description. """ - def __init__(self, fields, description='', encoding=None): + + def __init__(self, fields, description="", encoding=None): """ Parameters: @@ -476,14 +475,16 @@ class ManualSchema(ViewInspector): * `description`: String description for view. Optional. """ super(ManualSchema, self).__init__() - assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" + assert all( + isinstance(f, coreapi.Field) for f in fields + ), "`fields` must be a list of coreapi.Field instances" self._fields = fields self._description = description self._encoding = encoding def get_link(self, path, method, base_url): - if base_url and path.startswith('/'): + if base_url and path.startswith("/"): path = path[1:] return coreapi.Link( @@ -491,21 +492,22 @@ class ManualSchema(ViewInspector): action=method.lower(), encoding=self._encoding, fields=self._fields, - description=self._description + description=self._description, ) class DefaultSchema(ViewInspector): """Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting""" + def __get__(self, instance, owner): result = super(DefaultSchema, self).__get__(instance, owner) if not isinstance(result, DefaultSchema): return result inspector_class = api_settings.DEFAULT_SCHEMA_CLASS - assert issubclass(inspector_class, ViewInspector), ( - "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" - ) + assert issubclass( + inspector_class, ViewInspector + ), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass" inspector = inspector_class() inspector.view = instance return inspector diff --git a/rest_framework/schemas/utils.py b/rest_framework/schemas/utils.py index 76437a20a..3eacc686f 100644 --- a/rest_framework/schemas/utils.py +++ b/rest_framework/schemas/utils.py @@ -10,15 +10,15 @@ def is_list_view(path, method, view): """ Return True if the given path/method appears to represent a list view. """ - if hasattr(view, 'action'): + if hasattr(view, "action"): # Viewsets have an explicitly defined action, which we can inspect. - return view.action == 'list' + return view.action == "list" - if method.lower() != 'get': + if method.lower() != "get": return False if isinstance(view, RetrieveModelMixin): return False - path_components = path.strip('/').split('/') - if path_components and '{' in path_components[-1]: + path_components = path.strip("/").split("/") + if path_components and "{" in path_components[-1]: return False return True diff --git a/rest_framework/schemas/views.py b/rest_framework/schemas/views.py index f5e327a94..e609f733c 100644 --- a/rest_framework/schemas/views.py +++ b/rest_framework/schemas/views.py @@ -21,7 +21,7 @@ class SchemaView(APIView): if self.renderer_classes is None: self.renderer_classes = [ renderers.OpenAPIRenderer, - renderers.CoreJSONRenderer + renderers.CoreJSONRenderer, ] if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: self.renderer_classes += [renderers.BrowsableAPIRenderer] diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 9830edb3f..2cd8cccc0 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -17,12 +17,13 @@ import inspect import traceback from collections import OrderedDict -from django.core.exceptions import ImproperlyConfigured -from django.core.exceptions import ValidationError as DjangoValidationError +from django.core.exceptions import ( + ImproperlyConfigured, + ValidationError as DjangoValidationError, +) from django.db import models from django.db.models import DurationField as ModelDurationField -from django.db.models.fields import Field as DjangoModelField -from django.db.models.fields import FieldDoesNotExist +from django.db.models.fields import Field as DjangoModelField, FieldDoesNotExist from django.utils import six, timezone from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ @@ -33,18 +34,28 @@ from rest_framework.fields import get_error_detail, set_value from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation from rest_framework.utils.field_mapping import ( - ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, - get_relation_kwargs, get_url_kwargs + ClassLookupDict, + get_field_kwargs, + get_nested_relation_kwargs, + get_relation_kwargs, + get_url_kwargs, ) from rest_framework.utils.serializer_helpers import ( - BindingDict, BoundField, JSONBoundField, NestedBoundField, ReturnDict, - ReturnList + BindingDict, + BoundField, + JSONBoundField, + NestedBoundField, + ReturnDict, + ReturnList, ) from rest_framework.validators import ( - UniqueForDateValidator, UniqueForMonthValidator, UniqueForYearValidator, - UniqueTogetherValidator + UniqueForDateValidator, + UniqueForMonthValidator, + UniqueForYearValidator, + UniqueTogetherValidator, ) + # Note: We do the following so that users of the framework can use this style: # # example_field = serializers.CharField(...) @@ -52,37 +63,84 @@ from rest_framework.validators import ( # This helps keep the separation between model fields, form fields, and # serializer fields more explicit. from rest_framework.fields import ( # NOQA # isort:skip - BooleanField, CharField, ChoiceField, DateField, DateTimeField, DecimalField, - DictField, DurationField, EmailField, Field, FileField, FilePathField, FloatField, - HiddenField, HStoreField, IPAddressField, ImageField, IntegerField, JSONField, - ListField, ModelField, MultipleChoiceField, NullBooleanField, ReadOnlyField, - RegexField, SerializerMethodField, SlugField, TimeField, URLField, UUIDField, + BooleanField, + CharField, + ChoiceField, + DateField, + DateTimeField, + DecimalField, + DictField, + DurationField, + EmailField, + Field, + FileField, + FilePathField, + FloatField, + HiddenField, + HStoreField, + IPAddressField, + ImageField, + IntegerField, + JSONField, + ListField, + ModelField, + MultipleChoiceField, + NullBooleanField, + ReadOnlyField, + RegexField, + SerializerMethodField, + SlugField, + TimeField, + URLField, + UUIDField, ) from rest_framework.relations import ( # NOQA # isort:skip - HyperlinkedIdentityField, HyperlinkedRelatedField, ManyRelatedField, - PrimaryKeyRelatedField, RelatedField, SlugRelatedField, StringRelatedField, + HyperlinkedIdentityField, + HyperlinkedRelatedField, + ManyRelatedField, + PrimaryKeyRelatedField, + RelatedField, + SlugRelatedField, + StringRelatedField, ) # Non-field imports, but public API from rest_framework.fields import ( # NOQA # isort:skip - CreateOnlyDefault, CurrentUserDefault, SkipField, empty + CreateOnlyDefault, + CurrentUserDefault, + SkipField, + empty, ) from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip # We assume that 'validators' are intended for the child serializer, # rather than the parent serializer. LIST_SERIALIZER_KWARGS = ( - 'read_only', 'write_only', 'required', 'default', 'initial', 'source', - 'label', 'help_text', 'style', 'error_messages', 'allow_empty', - 'instance', 'data', 'partial', 'context', 'allow_null' + "read_only", + "write_only", + "required", + "default", + "initial", + "source", + "label", + "help_text", + "style", + "error_messages", + "allow_empty", + "instance", + "data", + "partial", + "context", + "allow_null", ) -ALL_FIELDS = '__all__' +ALL_FIELDS = "__all__" # BaseSerializer # -------------- + class BaseSerializer(Field): """ The BaseSerializer class provides a minimal class which may be used @@ -112,15 +170,15 @@ class BaseSerializer(Field): self.instance = instance if data is not empty: self.initial_data = data - self.partial = kwargs.pop('partial', False) - self._context = kwargs.pop('context', {}) - kwargs.pop('many', None) + self.partial = kwargs.pop("partial", False) + self._context = kwargs.pop("context", {}) + kwargs.pop("many", None) super(BaseSerializer, self).__init__(**kwargs) def __new__(cls, *args, **kwargs): # We override this method in order to automagically create # `ListSerializer` classes instead when `many=True` is set. - if kwargs.pop('many', False): + if kwargs.pop("many", False): return cls.many_init(*args, **kwargs) return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) @@ -141,51 +199,52 @@ class BaseSerializer(Field): kwargs['child'] = cls() return CustomListSerializer(*args, **kwargs) """ - allow_empty = kwargs.pop('allow_empty', None) + allow_empty = kwargs.pop("allow_empty", None) child_serializer = cls(*args, **kwargs) - list_kwargs = { - 'child': child_serializer, - } + list_kwargs = {"child": child_serializer} if allow_empty is not None: - list_kwargs['allow_empty'] = allow_empty - list_kwargs.update({ - key: value for key, value in kwargs.items() - if key in LIST_SERIALIZER_KWARGS - }) - meta = getattr(cls, 'Meta', None) - list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) + list_kwargs["allow_empty"] = allow_empty + list_kwargs.update( + { + key: value + for key, value in kwargs.items() + if key in LIST_SERIALIZER_KWARGS + } + ) + meta = getattr(cls, "Meta", None) + list_serializer_class = getattr(meta, "list_serializer_class", ListSerializer) return list_serializer_class(*args, **list_kwargs) def to_internal_value(self, data): - raise NotImplementedError('`to_internal_value()` must be implemented.') + raise NotImplementedError("`to_internal_value()` must be implemented.") def to_representation(self, instance): - raise NotImplementedError('`to_representation()` must be implemented.') + raise NotImplementedError("`to_representation()` must be implemented.") def update(self, instance, validated_data): - raise NotImplementedError('`update()` must be implemented.') + raise NotImplementedError("`update()` must be implemented.") def create(self, validated_data): - raise NotImplementedError('`create()` must be implemented.') + raise NotImplementedError("`create()` must be implemented.") def save(self, **kwargs): - assert not hasattr(self, 'save_object'), ( - 'Serializer `%s.%s` has old-style version 2 `.save_object()` ' - 'that is no longer compatible with REST framework 3. ' - 'Use the new-style `.create()` and `.update()` methods instead.' % - (self.__class__.__module__, self.__class__.__name__) + assert not hasattr(self, "save_object"), ( + "Serializer `%s.%s` has old-style version 2 `.save_object()` " + "that is no longer compatible with REST framework 3. " + "Use the new-style `.create()` and `.update()` methods instead." + % (self.__class__.__module__, self.__class__.__name__) ) - assert hasattr(self, '_errors'), ( - 'You must call `.is_valid()` before calling `.save()`.' - ) + assert hasattr( + self, "_errors" + ), "You must call `.is_valid()` before calling `.save()`." - assert not self.errors, ( - 'You cannot call `.save()` on a serializer with invalid data.' - ) + assert ( + not self.errors + ), "You cannot call `.save()` on a serializer with invalid data." # Guard against incorrect use of `serializer.save(commit=False)` - assert 'commit' not in kwargs, ( + assert "commit" not in kwargs, ( "'commit' is not a valid keyword argument to the 'save()' method. " "If you need to access data before committing to the database then " "inspect 'serializer.validated_data' instead. " @@ -194,44 +253,41 @@ class BaseSerializer(Field): "For example: 'serializer.save(owner=request.user)'.'" ) - assert not hasattr(self, '_data'), ( + assert not hasattr(self, "_data"), ( "You cannot call `.save()` after accessing `serializer.data`." "If you need to access data before committing to the database then " "inspect 'serializer.validated_data' instead. " ) - validated_data = dict( - list(self.validated_data.items()) + - list(kwargs.items()) - ) + validated_data = dict(list(self.validated_data.items()) + list(kwargs.items())) if self.instance is not None: self.instance = self.update(self.instance, validated_data) - assert self.instance is not None, ( - '`update()` did not return an object instance.' - ) + assert ( + self.instance is not None + ), "`update()` did not return an object instance." else: self.instance = self.create(validated_data) - assert self.instance is not None, ( - '`create()` did not return an object instance.' - ) + assert ( + self.instance is not None + ), "`create()` did not return an object instance." return self.instance def is_valid(self, raise_exception=False): - assert not hasattr(self, 'restore_object'), ( - 'Serializer `%s.%s` has old-style version 2 `.restore_object()` ' - 'that is no longer compatible with REST framework 3. ' - 'Use the new-style `.create()` and `.update()` methods instead.' % - (self.__class__.__module__, self.__class__.__name__) + assert not hasattr(self, "restore_object"), ( + "Serializer `%s.%s` has old-style version 2 `.restore_object()` " + "that is no longer compatible with REST framework 3. " + "Use the new-style `.create()` and `.update()` methods instead." + % (self.__class__.__module__, self.__class__.__name__) ) - assert hasattr(self, 'initial_data'), ( - 'Cannot call `.is_valid()` as no `data=` keyword argument was ' - 'passed when instantiating the serializer instance.' + assert hasattr(self, "initial_data"), ( + "Cannot call `.is_valid()` as no `data=` keyword argument was " + "passed when instantiating the serializer instance." ) - if not hasattr(self, '_validated_data'): + if not hasattr(self, "_validated_data"): try: self._validated_data = self.run_validation(self.initial_data) except ValidationError as exc: @@ -247,20 +303,22 @@ class BaseSerializer(Field): @property def data(self): - if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'): + if hasattr(self, "initial_data") and not hasattr(self, "_validated_data"): msg = ( - 'When a serializer is passed a `data` keyword argument you ' - 'must call `.is_valid()` before attempting to access the ' - 'serialized `.data` representation.\n' - 'You should either call `.is_valid()` first, ' - 'or access `.initial_data` instead.' + "When a serializer is passed a `data` keyword argument you " + "must call `.is_valid()` before attempting to access the " + "serialized `.data` representation.\n" + "You should either call `.is_valid()` first, " + "or access `.initial_data` instead." ) raise AssertionError(msg) - if not hasattr(self, '_data'): - if self.instance is not None and not getattr(self, '_errors', None): + if not hasattr(self, "_data"): + if self.instance is not None and not getattr(self, "_errors", None): self._data = self.to_representation(self.instance) - elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None): + elif hasattr(self, "_validated_data") and not getattr( + self, "_errors", None + ): self._data = self.to_representation(self.validated_data) else: self._data = self.get_initial() @@ -268,15 +326,15 @@ class BaseSerializer(Field): @property def errors(self): - if not hasattr(self, '_errors'): - msg = 'You must call `.is_valid()` before accessing `.errors`.' + if not hasattr(self, "_errors"): + msg = "You must call `.is_valid()` before accessing `.errors`." raise AssertionError(msg) return self._errors @property def validated_data(self): - if not hasattr(self, '_validated_data'): - msg = 'You must call `.is_valid()` before accessing `.validated_data`.' + if not hasattr(self, "_validated_data"): + msg = "You must call `.is_valid()` before accessing `.validated_data`." raise AssertionError(msg) return self._validated_data @@ -284,6 +342,7 @@ class BaseSerializer(Field): # Serializer & ListSerializer classes # ----------------------------------- + class SerializerMetaclass(type): """ This metaclass sets a dictionary named `_declared_fields` on the class. @@ -295,26 +354,28 @@ class SerializerMetaclass(type): @classmethod def _get_declared_fields(cls, bases, attrs): - fields = [(field_name, attrs.pop(field_name)) - for field_name, obj in list(attrs.items()) - if isinstance(obj, Field)] + fields = [ + (field_name, attrs.pop(field_name)) + for field_name, obj in list(attrs.items()) + if isinstance(obj, Field) + ] fields.sort(key=lambda x: x[1]._creation_counter) # If this class is subclassing another Serializer, add that Serializer's # fields. Note that we loop over the bases in *reverse*. This is necessary # in order to maintain the correct order of fields. for base in reversed(bases): - if hasattr(base, '_declared_fields'): + if hasattr(base, "_declared_fields"): fields = [ - (field_name, obj) for field_name, obj - in base._declared_fields.items() + (field_name, obj) + for field_name, obj in base._declared_fields.items() if field_name not in attrs ] + fields return OrderedDict(fields) def __new__(cls, name, bases, attrs): - attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) + attrs["_declared_fields"] = cls._get_declared_fields(bases, attrs) return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs) @@ -335,19 +396,15 @@ def as_serializer_error(exc): } elif isinstance(detail, list): # Errors raised as a list are non-field errors. - return { - api_settings.NON_FIELD_ERRORS_KEY: detail - } + return {api_settings.NON_FIELD_ERRORS_KEY: detail} # Errors raised as a string are non-field errors. - return { - api_settings.NON_FIELD_ERRORS_KEY: [detail] - } + return {api_settings.NON_FIELD_ERRORS_KEY: [detail]} @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): default_error_messages = { - 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') + "invalid": _("Invalid data. Expected a dictionary, but got {datatype}.") } @property @@ -358,7 +415,7 @@ class Serializer(BaseSerializer): # `fields` is evaluated lazily. We do this to ensure that we don't # have issues importing modules that use ModelSerializers as fields, # even if Django's app-loading stage has not yet run. - if not hasattr(self, '_fields'): + if not hasattr(self, "_fields"): self._fields = BindingDict(self) for key, value in self.get_fields().items(): self._fields[key] = value @@ -366,16 +423,11 @@ class Serializer(BaseSerializer): @cached_property def _writable_fields(self): - return [ - field for field in self.fields.values() if not field.read_only - ] + return [field for field in self.fields.values() if not field.read_only] @cached_property def _readable_fields(self): - return [ - field for field in self.fields.values() - if not field.write_only - ] + return [field for field in self.fields.values() if not field.write_only] def get_fields(self): """ @@ -391,28 +443,32 @@ class Serializer(BaseSerializer): Returns a list of validator callables. """ # Used by the lazily-evaluated `validators` property. - meta = getattr(self, 'Meta', None) - validators = getattr(meta, 'validators', None) + meta = getattr(self, "Meta", None) + validators = getattr(meta, "validators", None) return list(validators) if validators else [] def get_initial(self): - if hasattr(self, 'initial_data'): + if hasattr(self, "initial_data"): # initial_data may not be a valid type if not isinstance(self.initial_data, Mapping): return OrderedDict() - return OrderedDict([ - (field_name, field.get_value(self.initial_data)) - for field_name, field in self.fields.items() - if (field.get_value(self.initial_data) is not empty) and - not field.read_only - ]) + return OrderedDict( + [ + (field_name, field.get_value(self.initial_data)) + for field_name, field in self.fields.items() + if (field.get_value(self.initial_data) is not empty) + and not field.read_only + ] + ) - return OrderedDict([ - (field.field_name, field.get_initial()) - for field in self.fields.values() - if not field.read_only - ]) + return OrderedDict( + [ + (field.field_name, field.get_initial()) + for field in self.fields.values() + if not field.read_only + ] + ) def get_value(self, dictionary): # We override the default field access in order to support @@ -435,7 +491,7 @@ class Serializer(BaseSerializer): try: self.run_validators(value) value = self.validate(value) - assert value is not None, '.validate() should return the validated data' + assert value is not None, ".validate() should return the validated data" except (ValidationError, DjangoValidationError) as exc: raise ValidationError(detail=as_serializer_error(exc)) @@ -443,8 +499,12 @@ class Serializer(BaseSerializer): def _read_only_defaults(self): fields = [ - field for field in self.fields.values() - if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source) + field + for field in self.fields.values() + if (field.read_only) + and (field.default != empty) + and (field.source != "*") + and ("." not in field.source) ] defaults = OrderedDict() @@ -473,19 +533,19 @@ class Serializer(BaseSerializer): Dict of native values <- Dict of primitive datatypes. """ if not isinstance(data, Mapping): - message = self.error_messages['invalid'].format( + message = self.error_messages["invalid"].format( datatype=type(data).__name__ ) - raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] - }, code='invalid') + raise ValidationError( + {api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="invalid" + ) ret = OrderedDict() errors = OrderedDict() fields = self._writable_fields for field in fields: - validate_method = getattr(self, 'validate_' + field.field_name, None) + validate_method = getattr(self, "validate_" + field.field_name, None) primitive_value = field.get_value(data) try: validated_value = field.run_validation(primitive_value) @@ -523,7 +583,9 @@ class Serializer(BaseSerializer): # # For related fields with `use_pk_only_optimization` we need to # resolve the pk value. - check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute + check_for_none = ( + attribute.pk if isinstance(attribute, PKOnlyObject) else attribute + ) if check_for_none is None: ret[field.field_name] = None else: @@ -548,7 +610,7 @@ class Serializer(BaseSerializer): def __getitem__(self, key): field = self.fields[key] value = self.data.get(key) - error = self.errors.get(key) if hasattr(self, '_errors') else None + error = self.errors.get(key) if hasattr(self, "_errors") else None if isinstance(field, Serializer): return NestedBoundField(field, value, error) if isinstance(field, JSONField): @@ -566,10 +628,14 @@ class Serializer(BaseSerializer): @property def errors(self): ret = super(Serializer, self).errors - if isinstance(ret, list) and len(ret) == 1 and getattr(ret[0], 'code', None) == 'null': + if ( + isinstance(ret, list) + and len(ret) == 1 + and getattr(ret[0], "code", None) == "null" + ): # Edge case. Provide a more descriptive error than # "this field may not be null", when no data is passed. - detail = ErrorDetail('No data provided', code='null') + detail = ErrorDetail("No data provided", code="null") ret = {api_settings.NON_FIELD_ERRORS_KEY: [detail]} return ReturnDict(ret, serializer=self) @@ -577,29 +643,30 @@ class Serializer(BaseSerializer): # There's some replication of `ListField` here, # but that's probably better than obfuscating the call hierarchy. + class ListSerializer(BaseSerializer): child = None many = True default_error_messages = { - 'not_a_list': _('Expected a list of items but got type "{input_type}".'), - 'empty': _('This list may not be empty.') + "not_a_list": _('Expected a list of items but got type "{input_type}".'), + "empty": _("This list may not be empty."), } def __init__(self, *args, **kwargs): - self.child = kwargs.pop('child', copy.deepcopy(self.child)) - self.allow_empty = kwargs.pop('allow_empty', True) - assert self.child is not None, '`child` is a required argument.' - assert not inspect.isclass(self.child), '`child` has not been instantiated.' + self.child = kwargs.pop("child", copy.deepcopy(self.child)) + self.allow_empty = kwargs.pop("allow_empty", True) + assert self.child is not None, "`child` is a required argument." + assert not inspect.isclass(self.child), "`child` has not been instantiated." super(ListSerializer, self).__init__(*args, **kwargs) - self.child.bind(field_name='', parent=self) + self.child.bind(field_name="", parent=self) def bind(self, field_name, parent): super(ListSerializer, self).bind(field_name, parent) self.partial = self.parent.partial def get_initial(self): - if hasattr(self, 'initial_data'): + if hasattr(self, "initial_data"): return self.to_representation(self.initial_data) return [] @@ -610,7 +677,9 @@ class ListSerializer(BaseSerializer): # We override the default field access in order to support # lists in HTML forms. if html.is_html_input(dictionary): - return html.parse_html_list(dictionary, prefix=self.field_name, default=empty) + return html.parse_html_list( + dictionary, prefix=self.field_name, default=empty + ) return dictionary.get(self.field_name, empty) def run_validation(self, data=empty): @@ -627,7 +696,7 @@ class ListSerializer(BaseSerializer): try: self.run_validators(value) value = self.validate(value) - assert value is not None, '.validate() should return the validated data' + assert value is not None, ".validate() should return the validated data" except (ValidationError, DjangoValidationError) as exc: raise ValidationError(detail=as_serializer_error(exc)) @@ -641,21 +710,21 @@ class ListSerializer(BaseSerializer): data = html.parse_html_list(data, default=[]) if not isinstance(data, list): - message = self.error_messages['not_a_list'].format( + message = self.error_messages["not_a_list"].format( input_type=type(data).__name__ ) - raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] - }, code='not_a_list') + raise ValidationError( + {api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="not_a_list" + ) if not self.allow_empty and len(data) == 0: if self.parent and self.partial: raise SkipField() - message = self.error_messages['empty'] - raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] - }, code='empty') + message = self.error_messages["empty"] + raise ValidationError( + {api_settings.NON_FIELD_ERRORS_KEY: [message]}, code="empty" + ) ret = [] errors = [] @@ -682,9 +751,7 @@ class ListSerializer(BaseSerializer): # so, first get a queryset from the Manager if needed iterable = data.all() if isinstance(data, models.Manager) else data - return [ - self.child.to_representation(item) for item in iterable - ] + return [self.child.to_representation(item) for item in iterable] def validate(self, attrs): return attrs @@ -699,16 +766,14 @@ class ListSerializer(BaseSerializer): ) def create(self, validated_data): - return [ - self.child.create(attrs) for attrs in validated_data - ] + return [self.child.create(attrs) for attrs in validated_data] def save(self, **kwargs): """ Save and return a list of object instances. """ # Guard against incorrect use of `serializer.save(commit=False)` - assert 'commit' not in kwargs, ( + assert "commit" not in kwargs, ( "'commit' is not a valid keyword argument to the 'save()' method. " "If you need to access data before committing to the database then " "inspect 'serializer.validated_data' instead. " @@ -724,26 +789,26 @@ class ListSerializer(BaseSerializer): if self.instance is not None: self.instance = self.update(self.instance, validated_data) - assert self.instance is not None, ( - '`update()` did not return an object instance.' - ) + assert ( + self.instance is not None + ), "`update()` did not return an object instance." else: self.instance = self.create(validated_data) - assert self.instance is not None, ( - '`create()` did not return an object instance.' - ) + assert ( + self.instance is not None + ), "`create()` did not return an object instance." return self.instance def is_valid(self, raise_exception=False): # This implementation is the same as the default, # except that we use lists, rather than dicts, as the empty case. - assert hasattr(self, 'initial_data'), ( - 'Cannot call `.is_valid()` as no `data=` keyword argument was ' - 'passed when instantiating the serializer instance.' + assert hasattr(self, "initial_data"), ( + "Cannot call `.is_valid()` as no `data=` keyword argument was " + "passed when instantiating the serializer instance." ) - if not hasattr(self, '_validated_data'): + if not hasattr(self, "_validated_data"): try: self._validated_data = self.run_validation(self.initial_data) except ValidationError as exc: @@ -771,10 +836,14 @@ class ListSerializer(BaseSerializer): @property def errors(self): ret = super(ListSerializer, self).errors - if isinstance(ret, list) and len(ret) == 1 and getattr(ret[0], 'code', None) == 'null': + if ( + isinstance(ret, list) + and len(ret) == 1 + and getattr(ret[0], "code", None) == "null" + ): # Edge case. Provide a more descriptive error than # "this field may not be null", when no data is passed. - detail = ErrorDetail('No data provided', code='null') + detail = ErrorDetail("No data provided", code="null") ret = {api_settings.NON_FIELD_ERRORS_KEY: [detail]} if isinstance(ret, dict): return ReturnDict(ret, serializer=self) @@ -784,6 +853,7 @@ class ListSerializer(BaseSerializer): # ModelSerializer & HyperlinkedModelSerializer # -------------------------------------------- + def raise_errors_on_nested_writes(method_name, serializer, validated_data): """ Give explicit errors when users attempt to pass writable nested data. @@ -809,18 +879,18 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data): # ... # profile = ProfileSerializer() assert not any( - isinstance(field, BaseSerializer) and - (field.source in validated_data) and - isinstance(validated_data[field.source], (list, dict)) + isinstance(field, BaseSerializer) + and (field.source in validated_data) + and isinstance(validated_data[field.source], (list, dict)) for field in serializer._writable_fields ), ( - 'The `.{method_name}()` method does not support writable nested ' - 'fields by default.\nWrite an explicit `.{method_name}()` method for ' - 'serializer `{module}.{class_name}`, or set `read_only=True` on ' - 'nested serializer fields.'.format( + "The `.{method_name}()` method does not support writable nested " + "fields by default.\nWrite an explicit `.{method_name}()` method for " + "serializer `{module}.{class_name}`, or set `read_only=True` on " + "nested serializer fields.".format( method_name=method_name, module=serializer.__class__.__module__, - class_name=serializer.__class__.__name__ + class_name=serializer.__class__.__name__, ) ) @@ -830,18 +900,18 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data): # ... # address = serializer.CharField('profile.address') assert not any( - '.' in field.source and - (key in validated_data) and - isinstance(validated_data[key], (list, dict)) + "." in field.source + and (key in validated_data) + and isinstance(validated_data[key], (list, dict)) for key, field in serializer.fields.items() ), ( - 'The `.{method_name}()` method does not support writable dotted-source ' - 'fields by default.\nWrite an explicit `.{method_name}()` method for ' - 'serializer `{module}.{class_name}`, or set `read_only=True` on ' - 'dotted-source serializer fields.'.format( + "The `.{method_name}()` method does not support writable dotted-source " + "fields by default.\nWrite an explicit `.{method_name}()` method for " + "serializer `{module}.{class_name}`, or set `read_only=True` on " + "dotted-source serializer fields.".format( method_name=method_name, module=serializer.__class__.__module__, - class_name=serializer.__class__.__name__ + class_name=serializer.__class__.__name__, ) ) @@ -862,6 +932,7 @@ class ModelSerializer(Serializer): you need you should either declare the extra/differing fields explicitly on the serializer class, or simply use a `Serializer` class. """ + serializer_field_mapping = { models.AutoField: IntegerField, models.BigIntegerField: IntegerField, @@ -926,7 +997,7 @@ class ModelSerializer(Serializer): If you want to support writable nested relationships you'll need to write an explicit `.create()` method. """ - raise_errors_on_nested_writes('create', self, validated_data) + raise_errors_on_nested_writes("create", self, validated_data) ModelClass = self.Meta.model @@ -944,19 +1015,19 @@ class ModelSerializer(Serializer): except TypeError: tb = traceback.format_exc() msg = ( - 'Got a `TypeError` when calling `%s.%s.create()`. ' - 'This may be because you have a writable field on the ' - 'serializer class that is not a valid argument to ' - '`%s.%s.create()`. You may need to make the field ' - 'read-only, or override the %s.create() method to handle ' - 'this correctly.\nOriginal exception was:\n %s' % - ( + "Got a `TypeError` when calling `%s.%s.create()`. " + "This may be because you have a writable field on the " + "serializer class that is not a valid argument to " + "`%s.%s.create()`. You may need to make the field " + "read-only, or override the %s.create() method to handle " + "this correctly.\nOriginal exception was:\n %s" + % ( ModelClass.__name__, ModelClass._default_manager.name, ModelClass.__name__, ModelClass._default_manager.name, self.__class__.__name__, - tb + tb, ) ) raise TypeError(msg) @@ -970,7 +1041,7 @@ class ModelSerializer(Serializer): return instance def update(self, instance, validated_data): - raise_errors_on_nested_writes('update', self, validated_data) + raise_errors_on_nested_writes("update", self, validated_data) info = model_meta.get_field_info(instance) # Simply set each attribute on the instance, and then save it. @@ -997,24 +1068,22 @@ class ModelSerializer(Serializer): if self.url_field_name is None: self.url_field_name = api_settings.URL_FIELD_NAME - assert hasattr(self, 'Meta'), ( - 'Class {serializer_class} missing "Meta" attribute'.format( - serializer_class=self.__class__.__name__ - ) + assert hasattr( + self, "Meta" + ), 'Class {serializer_class} missing "Meta" attribute'.format( + serializer_class=self.__class__.__name__ ) - assert hasattr(self.Meta, 'model'), ( - 'Class {serializer_class} missing "Meta.model" attribute'.format( - serializer_class=self.__class__.__name__ - ) + assert hasattr( + self.Meta, "model" + ), 'Class {serializer_class} missing "Meta.model" attribute'.format( + serializer_class=self.__class__.__name__ ) if model_meta.is_abstract_model(self.Meta.model): - raise ValueError( - 'Cannot use ModelSerializer with Abstract Models.' - ) + raise ValueError("Cannot use ModelSerializer with Abstract Models.") declared_fields = copy.deepcopy(self._declared_fields) - model = getattr(self.Meta, 'model') - depth = getattr(self.Meta, 'depth', 0) + model = getattr(self.Meta, "model") + depth = getattr(self.Meta, "depth", 0) if depth is not None: assert depth >= 0, "'depth' may not be negative." @@ -1041,19 +1110,15 @@ class ModelSerializer(Serializer): continue extra_field_kwargs = extra_kwargs.get(field_name, {}) - source = extra_field_kwargs.get('source', '*') - if source == '*': + source = extra_field_kwargs.get("source", "*") + if source == "*": source = field_name # Determine the serializer field class and keyword arguments. - field_class, field_kwargs = self.build_field( - source, info, model, depth - ) + field_class, field_kwargs = self.build_field(source, info, model, depth) # Include any kwargs defined in `Meta.extra_kwargs` - field_kwargs = self.include_extra_kwargs( - field_kwargs, extra_field_kwargs - ) + field_kwargs = self.include_extra_kwargs(field_kwargs, extra_field_kwargs) # Create the serializer field. fields[field_name] = field_class(**field_kwargs) @@ -1072,19 +1137,19 @@ class ModelSerializer(Serializer): set of fields, but also takes into account the `Meta.fields` or `Meta.exclude` options if they have been specified. """ - fields = getattr(self.Meta, 'fields', None) - exclude = getattr(self.Meta, 'exclude', None) + fields = getattr(self.Meta, "fields", None) + exclude = getattr(self.Meta, "exclude", None) if fields and fields != ALL_FIELDS and not isinstance(fields, (list, tuple)): raise TypeError( 'The `fields` option must be a list or tuple or "__all__". ' - 'Got %s.' % type(fields).__name__ + "Got %s." % type(fields).__name__ ) if exclude and not isinstance(exclude, (list, tuple)): raise TypeError( - 'The `exclude` option must be a list or tuple. Got %s.' % - type(exclude).__name__ + "The `exclude` option must be a list or tuple. Got %s." + % type(exclude).__name__ ) assert not (fields and exclude), ( @@ -1115,15 +1180,14 @@ class ModelSerializer(Serializer): # a subset of fields. required_field_names = set(declared_fields) for cls in self.__class__.__bases__: - required_field_names -= set(getattr(cls, '_declared_fields', [])) + required_field_names -= set(getattr(cls, "_declared_fields", [])) for field_name in required_field_names: assert field_name in fields, ( "The field '{field_name}' was declared on serializer " "{serializer_class}, but has not been included in the " "'fields' option.".format( - field_name=field_name, - serializer_class=self.__class__.__name__ + field_name=field_name, serializer_class=self.__class__.__name__ ) ) return fields @@ -1138,10 +1202,8 @@ class ModelSerializer(Serializer): "Cannot both declare the field '{field_name}' and include " "it in the {serializer_class} 'exclude' option. Remove the " "field or, if inherited from a parent serializer, disable " - "with `{field_name} = None`." - .format( - field_name=field_name, - serializer_class=self.__class__.__name__ + "with `{field_name} = None`.".format( + field_name=field_name, serializer_class=self.__class__.__name__ ) ) @@ -1149,8 +1211,7 @@ class ModelSerializer(Serializer): "The field '{field_name}' was included on serializer " "{serializer_class} in the 'exclude' option, but does " "not match any model field.".format( - field_name=field_name, - serializer_class=self.__class__.__name__ + field_name=field_name, serializer_class=self.__class__.__name__ ) ) fields.remove(field_name) @@ -1163,10 +1224,10 @@ class ModelSerializer(Serializer): `Meta.fields` option is not specified. """ return ( - [model_info.pk.name] + - list(declared_fields) + - list(model_info.fields) + - list(model_info.forward_relations) + [model_info.pk.name] + + list(declared_fields) + + list(model_info.fields) + + list(model_info.forward_relations) ) # Methods for constructing serializer fields... @@ -1206,9 +1267,9 @@ class ModelSerializer(Serializer): # Special case to handle when a OneToOneField is also the primary key if model_field.one_to_one and model_field.primary_key: field_class = self.serializer_related_field - field_kwargs['queryset'] = model_field.related_model.objects + field_kwargs["queryset"] = model_field.related_model.objects - if 'choices' in field_kwargs: + if "choices" in field_kwargs: # Fields with choices get coerced into `ChoiceField` # instead of using their regular typed field. field_class = self.serializer_choice_field @@ -1216,11 +1277,20 @@ class ModelSerializer(Serializer): # for the choice field. We need to strip these out. # Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES) valid_kwargs = { - 'read_only', 'write_only', - 'required', 'default', 'initial', 'source', - 'label', 'help_text', 'style', - 'error_messages', 'validators', 'allow_null', 'allow_blank', - 'choices' + "read_only", + "write_only", + "required", + "default", + "initial", + "source", + "label", + "help_text", + "style", + "error_messages", + "validators", + "allow_null", + "allow_blank", + "choices", } for key in list(field_kwargs): if key not in valid_kwargs: @@ -1230,20 +1300,22 @@ class ModelSerializer(Serializer): # `model_field` is only valid for the fallback case of # `ModelField`, which is used when no other typed field # matched to the model field. - field_kwargs.pop('model_field', None) + field_kwargs.pop("model_field", None) - if not issubclass(field_class, CharField) and not issubclass(field_class, ChoiceField): + if not issubclass(field_class, CharField) and not issubclass( + field_class, ChoiceField + ): # `allow_blank` is only valid for textual fields. - field_kwargs.pop('allow_blank', None) + field_kwargs.pop("allow_blank", None) if postgres_fields and isinstance(model_field, postgres_fields.ArrayField): # Populate the `child` argument on `ListField` instances generated # for the PostgreSQL specific `ArrayField`. child_model_field = model_field.base_field child_field_class, child_field_kwargs = self.build_standard_field( - 'child', child_model_field + "child", child_model_field ) - field_kwargs['child'] = child_field_class(**child_field_kwargs) + field_kwargs["child"] = child_field_class(**child_field_kwargs) return field_class, field_kwargs @@ -1254,14 +1326,18 @@ class ModelSerializer(Serializer): field_class = self.serializer_related_field field_kwargs = get_relation_kwargs(field_name, relation_info) - to_field = field_kwargs.pop('to_field', None) - if to_field and not relation_info.reverse and not relation_info.related_model._meta.get_field(to_field).primary_key: - field_kwargs['slug_field'] = to_field + to_field = field_kwargs.pop("to_field", None) + if ( + to_field + and not relation_info.reverse + and not relation_info.related_model._meta.get_field(to_field).primary_key + ): + field_kwargs["slug_field"] = to_field field_class = self.serializer_related_to_field # `view_name` is only valid for hyperlinked relationships. if not issubclass(field_class, HyperlinkedRelatedField): - field_kwargs.pop('view_name', None) + field_kwargs.pop("view_name", None) return field_class, field_kwargs @@ -1269,11 +1345,12 @@ class ModelSerializer(Serializer): """ Create nested fields for forward and reverse relationships. """ + class NestedSerializer(ModelSerializer): class Meta: model = relation_info.related_model depth = nested_depth - 1 - fields = '__all__' + fields = "__all__" field_class = NestedSerializer field_kwargs = get_nested_relation_kwargs(relation_info) @@ -1303,8 +1380,8 @@ class ModelSerializer(Serializer): Raise an error on any unknown fields. """ raise ImproperlyConfigured( - 'Field name `%s` is not valid for model `%s`.' % - (field_name, model_class.__name__) + "Field name `%s` is not valid for model `%s`." + % (field_name, model_class.__name__) ) def include_extra_kwargs(self, kwargs, extra_kwargs): @@ -1312,19 +1389,28 @@ class ModelSerializer(Serializer): Include any 'extra_kwargs' that have been included for this field, possibly removing any incompatible existing keyword arguments. """ - if extra_kwargs.get('read_only', False): + if extra_kwargs.get("read_only", False): for attr in [ - 'required', 'default', 'allow_blank', 'allow_null', - 'min_length', 'max_length', 'min_value', 'max_value', - 'validators', 'queryset' + "required", + "default", + "allow_blank", + "allow_null", + "min_length", + "max_length", + "min_value", + "max_value", + "validators", + "queryset", ]: kwargs.pop(attr, None) - if extra_kwargs.get('default') and kwargs.get('required') is False: - kwargs.pop('required') + if extra_kwargs.get("default") and kwargs.get("required") is False: + kwargs.pop("required") - if extra_kwargs.get('read_only', kwargs.get('read_only', False)): - extra_kwargs.pop('required', None) # Read only fields should always omit the 'required' argument. + if extra_kwargs.get("read_only", kwargs.get("read_only", False)): + extra_kwargs.pop( + "required", None + ) # Read only fields should always omit the 'required' argument. kwargs.update(extra_kwargs) @@ -1337,27 +1423,27 @@ class ModelSerializer(Serializer): Return a dictionary mapping field names to a dictionary of additional keyword arguments. """ - extra_kwargs = copy.deepcopy(getattr(self.Meta, 'extra_kwargs', {})) + extra_kwargs = copy.deepcopy(getattr(self.Meta, "extra_kwargs", {})) - read_only_fields = getattr(self.Meta, 'read_only_fields', None) + read_only_fields = getattr(self.Meta, "read_only_fields", None) if read_only_fields is not None: if not isinstance(read_only_fields, (list, tuple)): raise TypeError( - 'The `read_only_fields` option must be a list or tuple. ' - 'Got %s.' % type(read_only_fields).__name__ + "The `read_only_fields` option must be a list or tuple. " + "Got %s." % type(read_only_fields).__name__ ) for field_name in read_only_fields: kwargs = extra_kwargs.get(field_name, {}) - kwargs['read_only'] = True + kwargs["read_only"] = True extra_kwargs[field_name] = kwargs else: # Guard against the possible misspelling `readonly_fields` (used # by the Django admin and others). - assert not hasattr(self.Meta, 'readonly_fields'), ( - 'Serializer `%s.%s` has field `readonly_fields`; ' - 'the correct spelling for the option is `read_only_fields`.' % - (self.__class__.__module__, self.__class__.__name__) + assert not hasattr(self.Meta, "readonly_fields"), ( + "Serializer `%s.%s` has field `readonly_fields`; " + "the correct spelling for the option is `read_only_fields`." + % (self.__class__.__module__, self.__class__.__name__) ) return extra_kwargs @@ -1370,10 +1456,10 @@ class ModelSerializer(Serializer): ('dict of updated extra kwargs', 'mapping of hidden fields') """ - if getattr(self.Meta, 'validators', None) is not None: + if getattr(self.Meta, "validators", None) is not None: return (extra_kwargs, {}) - model = getattr(self.Meta, 'model') + model = getattr(self.Meta, "model") model_fields = self._get_model_fields( field_names, declared_fields, extra_kwargs ) @@ -1385,8 +1471,11 @@ class ModelSerializer(Serializer): for model_field in model_fields.values(): # Include each of the `unique_for_*` field names. - unique_constraint_names |= {model_field.unique_for_date, model_field.unique_for_month, - model_field.unique_for_year} + unique_constraint_names |= { + model_field.unique_for_date, + model_field.unique_for_month, + model_field.unique_for_year, + } unique_constraint_names -= {None} @@ -1407,9 +1496,9 @@ class ModelSerializer(Serializer): # Get the model field that is referred too. unique_constraint_field = model._meta.get_field(unique_constraint_name) - if getattr(unique_constraint_field, 'auto_now_add', None): + if getattr(unique_constraint_field, "auto_now_add", None): default = CreateOnlyDefault(timezone.now) - elif getattr(unique_constraint_field, 'auto_now', None): + elif getattr(unique_constraint_field, "auto_now", None): default = timezone.now elif unique_constraint_field.has_default(): default = unique_constraint_field.default @@ -1419,9 +1508,11 @@ class ModelSerializer(Serializer): if unique_constraint_name in model_fields: # The corresponding field is present in the serializer if default is empty: - uniqueness_extra_kwargs[unique_constraint_name] = {'required': True} + uniqueness_extra_kwargs[unique_constraint_name] = {"required": True} else: - uniqueness_extra_kwargs[unique_constraint_name] = {'default': default} + uniqueness_extra_kwargs[unique_constraint_name] = { + "default": default + } elif default is not empty: # The corresponding field is not present in the # serializer. We have a default to use for it, so @@ -1443,7 +1534,7 @@ class ModelSerializer(Serializer): Returned as a dict of 'model field name' -> 'model field'. Used internally by `get_uniqueness_field_options`. """ - model = getattr(self.Meta, 'model') + model = getattr(self.Meta, "model") model_fields = {} for field_name in field_names: @@ -1453,11 +1544,11 @@ class ModelSerializer(Serializer): source = field.source or field_name else: try: - source = extra_kwargs[field_name]['source'] + source = extra_kwargs[field_name]["source"] except KeyError: source = field_name - if '.' in source or source == '*': + if "." in source or source == "*": # Model fields will always have a simple source mapping, # they can't be nested attribute lookups. continue @@ -1478,23 +1569,22 @@ class ModelSerializer(Serializer): Determine the set of validators to use when instantiating serializer. """ # If the validators have been declared explicitly then use that. - validators = getattr(getattr(self, 'Meta', None), 'validators', None) + validators = getattr(getattr(self, "Meta", None), "validators", None) if validators is not None: return list(validators) # Otherwise use the default set of validators. return ( - self.get_unique_together_validators() + - self.get_unique_for_date_validators() + self.get_unique_together_validators() + + self.get_unique_for_date_validators() ) def get_unique_together_validators(self): """ Determine a default set of validators for any unique_together constraints. """ - model_class_inheritance_tree = ( - [self.Meta.model] + - list(self.Meta.model._meta.parents) + model_class_inheritance_tree = [self.Meta.model] + list( + self.Meta.model._meta.parents ) # The field names we're passing though here only include fields @@ -1502,14 +1592,19 @@ class ModelSerializer(Serializer): # cannot map to a field, and must be a traversal, so we're not # including those. field_names = { - field.source for field in self._writable_fields - if (field.source != '*') and ('.' not in field.source) + field.source + for field in self._writable_fields + if (field.source != "*") and ("." not in field.source) } # Special Case: Add read_only fields with defaults. field_names |= { - field.source for field in self.fields.values() - if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source) + field.source + for field in self.fields.values() + if (field.read_only) + and (field.default != empty) + and (field.source != "*") + and ("." not in field.source) } # Note that we make sure to check `unique_together` both on the @@ -1519,8 +1614,7 @@ class ModelSerializer(Serializer): for unique_together in parent_class._meta.unique_together: if field_names.issuperset(set(unique_together)): validator = UniqueTogetherValidator( - queryset=parent_class._default_manager, - fields=unique_together + queryset=parent_class._default_manager, fields=unique_together ) validators.append(validator) return validators @@ -1544,7 +1638,7 @@ class ModelSerializer(Serializer): validator = UniqueForDateValidator( queryset=default_manager, field=field_name, - date_field=field.unique_for_date + date_field=field.unique_for_date, ) validators.append(validator) @@ -1552,7 +1646,7 @@ class ModelSerializer(Serializer): validator = UniqueForMonthValidator( queryset=default_manager, field=field_name, - date_field=field.unique_for_month + date_field=field.unique_for_month, ) validators.append(validator) @@ -1560,18 +1654,18 @@ class ModelSerializer(Serializer): validator = UniqueForYearValidator( queryset=default_manager, field=field_name, - date_field=field.unique_for_year + date_field=field.unique_for_year, ) validators.append(validator) return validators -if hasattr(models, 'UUIDField'): +if hasattr(models, "UUIDField"): ModelSerializer.serializer_field_mapping[models.UUIDField] = UUIDField # IPAddressField is deprecated in Django -if hasattr(models, 'IPAddressField'): +if hasattr(models, "IPAddressField"): ModelSerializer.serializer_field_mapping[models.IPAddressField] = IPAddressField if postgres_fields: @@ -1588,6 +1682,7 @@ class HyperlinkedModelSerializer(ModelSerializer): * A 'url' field is included instead of the 'id' field. * Relationships to other instances are hyperlinks, instead of primary keys. """ + serializer_related_field = HyperlinkedRelatedField def get_default_field_names(self, declared_fields, model_info): @@ -1596,21 +1691,22 @@ class HyperlinkedModelSerializer(ModelSerializer): `Meta.fields` option is not specified. """ return ( - [self.url_field_name] + - list(declared_fields) + - list(model_info.fields) + - list(model_info.forward_relations) + [self.url_field_name] + + list(declared_fields) + + list(model_info.fields) + + list(model_info.forward_relations) ) def build_nested_field(self, field_name, relation_info, nested_depth): """ Create nested fields for forward and reverse relationships. """ + class NestedSerializer(HyperlinkedModelSerializer): class Meta: model = relation_info.related_model depth = nested_depth - 1 - fields = '__all__' + fields = "__all__" field_class = NestedSerializer field_kwargs = get_nested_relation_kwargs(relation_info) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 8db9c81ed..9fd09a9b1 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -28,135 +28,109 @@ from django.utils import six from rest_framework import ISO_8601 + DEFAULTS = { # Base API policies - 'DEFAULT_RENDERER_CLASSES': ( - 'rest_framework.renderers.JSONRenderer', - 'rest_framework.renderers.BrowsableAPIRenderer', + "DEFAULT_RENDERER_CLASSES": ( + "rest_framework.renderers.JSONRenderer", + "rest_framework.renderers.BrowsableAPIRenderer", ), - 'DEFAULT_PARSER_CLASSES': ( - 'rest_framework.parsers.JSONParser', - 'rest_framework.parsers.FormParser', - 'rest_framework.parsers.MultiPartParser' + "DEFAULT_PARSER_CLASSES": ( + "rest_framework.parsers.JSONParser", + "rest_framework.parsers.FormParser", + "rest_framework.parsers.MultiPartParser", ), - 'DEFAULT_AUTHENTICATION_CLASSES': ( - 'rest_framework.authentication.SessionAuthentication', - 'rest_framework.authentication.BasicAuthentication' + "DEFAULT_AUTHENTICATION_CLASSES": ( + "rest_framework.authentication.SessionAuthentication", + "rest_framework.authentication.BasicAuthentication", ), - 'DEFAULT_PERMISSION_CLASSES': ( - 'rest_framework.permissions.AllowAny', - ), - 'DEFAULT_THROTTLE_CLASSES': (), - 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', - 'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata', - 'DEFAULT_VERSIONING_CLASS': None, - + "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.AllowAny",), + "DEFAULT_THROTTLE_CLASSES": (), + "DEFAULT_CONTENT_NEGOTIATION_CLASS": "rest_framework.negotiation.DefaultContentNegotiation", + "DEFAULT_METADATA_CLASS": "rest_framework.metadata.SimpleMetadata", + "DEFAULT_VERSIONING_CLASS": None, # Generic view behavior - 'DEFAULT_PAGINATION_CLASS': None, - 'DEFAULT_FILTER_BACKENDS': (), - + "DEFAULT_PAGINATION_CLASS": None, + "DEFAULT_FILTER_BACKENDS": (), # Schema - 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema', - + "DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.AutoSchema", # Throttling - 'DEFAULT_THROTTLE_RATES': { - 'user': None, - 'anon': None, - }, - 'NUM_PROXIES': None, - + "DEFAULT_THROTTLE_RATES": {"user": None, "anon": None}, + "NUM_PROXIES": None, # Pagination - 'PAGE_SIZE': None, - + "PAGE_SIZE": None, # Filtering - 'SEARCH_PARAM': 'search', - 'ORDERING_PARAM': 'ordering', - + "SEARCH_PARAM": "search", + "ORDERING_PARAM": "ordering", # Versioning - 'DEFAULT_VERSION': None, - 'ALLOWED_VERSIONS': None, - 'VERSION_PARAM': 'version', - + "DEFAULT_VERSION": None, + "ALLOWED_VERSIONS": None, + "VERSION_PARAM": "version", # Authentication - 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', - 'UNAUTHENTICATED_TOKEN': None, - + "UNAUTHENTICATED_USER": "django.contrib.auth.models.AnonymousUser", + "UNAUTHENTICATED_TOKEN": None, # View configuration - 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', - 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', - + "VIEW_NAME_FUNCTION": "rest_framework.views.get_view_name", + "VIEW_DESCRIPTION_FUNCTION": "rest_framework.views.get_view_description", # Exception handling - 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', - 'NON_FIELD_ERRORS_KEY': 'non_field_errors', - + "EXCEPTION_HANDLER": "rest_framework.views.exception_handler", + "NON_FIELD_ERRORS_KEY": "non_field_errors", # Testing - 'TEST_REQUEST_RENDERER_CLASSES': ( - 'rest_framework.renderers.MultiPartRenderer', - 'rest_framework.renderers.JSONRenderer' + "TEST_REQUEST_RENDERER_CLASSES": ( + "rest_framework.renderers.MultiPartRenderer", + "rest_framework.renderers.JSONRenderer", ), - 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', - + "TEST_REQUEST_DEFAULT_FORMAT": "multipart", # Hyperlink settings - 'URL_FORMAT_OVERRIDE': 'format', - 'FORMAT_SUFFIX_KWARG': 'format', - 'URL_FIELD_NAME': 'url', - + "URL_FORMAT_OVERRIDE": "format", + "FORMAT_SUFFIX_KWARG": "format", + "URL_FIELD_NAME": "url", # Input and output formats - 'DATE_FORMAT': ISO_8601, - 'DATE_INPUT_FORMATS': (ISO_8601,), - - 'DATETIME_FORMAT': ISO_8601, - 'DATETIME_INPUT_FORMATS': (ISO_8601,), - - 'TIME_FORMAT': ISO_8601, - 'TIME_INPUT_FORMATS': (ISO_8601,), - + "DATE_FORMAT": ISO_8601, + "DATE_INPUT_FORMATS": (ISO_8601,), + "DATETIME_FORMAT": ISO_8601, + "DATETIME_INPUT_FORMATS": (ISO_8601,), + "TIME_FORMAT": ISO_8601, + "TIME_INPUT_FORMATS": (ISO_8601,), # Encoding - 'UNICODE_JSON': True, - 'COMPACT_JSON': True, - 'STRICT_JSON': True, - 'COERCE_DECIMAL_TO_STRING': True, - 'UPLOADED_FILES_USE_URL': True, - + "UNICODE_JSON": True, + "COMPACT_JSON": True, + "STRICT_JSON": True, + "COERCE_DECIMAL_TO_STRING": True, + "UPLOADED_FILES_USE_URL": True, # Browseable API - 'HTML_SELECT_CUTOFF': 1000, - 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", - + "HTML_SELECT_CUTOFF": 1000, + "HTML_SELECT_CUTOFF_TEXT": "More than {count} items...", # Schemas - 'SCHEMA_COERCE_PATH_PK': True, - 'SCHEMA_COERCE_METHOD_NAMES': { - 'retrieve': 'read', - 'destroy': 'delete' - }, + "SCHEMA_COERCE_PATH_PK": True, + "SCHEMA_COERCE_METHOD_NAMES": {"retrieve": "read", "destroy": "delete"}, } # List of settings that may be in string import notation. IMPORT_STRINGS = ( - 'DEFAULT_RENDERER_CLASSES', - 'DEFAULT_PARSER_CLASSES', - 'DEFAULT_AUTHENTICATION_CLASSES', - 'DEFAULT_PERMISSION_CLASSES', - 'DEFAULT_THROTTLE_CLASSES', - 'DEFAULT_CONTENT_NEGOTIATION_CLASS', - 'DEFAULT_METADATA_CLASS', - 'DEFAULT_VERSIONING_CLASS', - 'DEFAULT_PAGINATION_CLASS', - 'DEFAULT_FILTER_BACKENDS', - 'DEFAULT_SCHEMA_CLASS', - 'EXCEPTION_HANDLER', - 'TEST_REQUEST_RENDERER_CLASSES', - 'UNAUTHENTICATED_USER', - 'UNAUTHENTICATED_TOKEN', - 'VIEW_NAME_FUNCTION', - 'VIEW_DESCRIPTION_FUNCTION' + "DEFAULT_RENDERER_CLASSES", + "DEFAULT_PARSER_CLASSES", + "DEFAULT_AUTHENTICATION_CLASSES", + "DEFAULT_PERMISSION_CLASSES", + "DEFAULT_THROTTLE_CLASSES", + "DEFAULT_CONTENT_NEGOTIATION_CLASS", + "DEFAULT_METADATA_CLASS", + "DEFAULT_VERSIONING_CLASS", + "DEFAULT_PAGINATION_CLASS", + "DEFAULT_FILTER_BACKENDS", + "DEFAULT_SCHEMA_CLASS", + "EXCEPTION_HANDLER", + "TEST_REQUEST_RENDERER_CLASSES", + "UNAUTHENTICATED_USER", + "UNAUTHENTICATED_TOKEN", + "VIEW_NAME_FUNCTION", + "VIEW_DESCRIPTION_FUNCTION", ) # List of settings that have been removed -REMOVED_SETTINGS = ( - "PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY", -) +REMOVED_SETTINGS = ("PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY") def perform_import(val, setting_name): @@ -179,11 +153,16 @@ def import_from_string(val, setting_name): """ try: # Nod to tastypie's use of importlib. - module_path, class_name = val.rsplit('.', 1) + module_path, class_name = val.rsplit(".", 1) module = import_module(module_path) return getattr(module, class_name) except (ImportError, AttributeError) as e: - msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) + msg = "Could not import '%s' for API setting '%s'. %s: %s." % ( + val, + setting_name, + e.__class__.__name__, + e, + ) raise ImportError(msg) @@ -198,6 +177,7 @@ class APISettings(object): Any setting with string import paths will be automatically resolved and return the class, rather than the string literal. """ + def __init__(self, user_settings=None, defaults=None, import_strings=None): if user_settings: self._user_settings = self.__check_user_settings(user_settings) @@ -207,8 +187,8 @@ class APISettings(object): @property def user_settings(self): - if not hasattr(self, '_user_settings'): - self._user_settings = getattr(settings, 'REST_FRAMEWORK', {}) + if not hasattr(self, "_user_settings"): + self._user_settings = getattr(settings, "REST_FRAMEWORK", {}) return self._user_settings def __getattr__(self, attr): @@ -235,23 +215,26 @@ class APISettings(object): SETTINGS_DOC = "https://www.django-rest-framework.org/api-guide/settings/" for setting in REMOVED_SETTINGS: if setting in user_settings: - raise RuntimeError("The '%s' setting has been removed. Please refer to '%s' for available settings." % (setting, SETTINGS_DOC)) + raise RuntimeError( + "The '%s' setting has been removed. Please refer to '%s' for available settings." + % (setting, SETTINGS_DOC) + ) return user_settings def reload(self): for attr in self._cached_attrs: delattr(self, attr) self._cached_attrs.clear() - if hasattr(self, '_user_settings'): - delattr(self, '_user_settings') + if hasattr(self, "_user_settings"): + delattr(self, "_user_settings") api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS) def reload_api_settings(*args, **kwargs): - setting = kwargs['setting'] - if setting == 'REST_FRAMEWORK': + setting = kwargs["setting"] + if setting == "REST_FRAMEWORK": api_settings.reload() diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index f48675d5e..b71599a5d 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -15,22 +15,23 @@ from rest_framework.compat import apply_markdown, pygments_highlight from rest_framework.renderers import HTMLFormRenderer from rest_framework.utils.urls import replace_query_param + register = template.Library() # Regex for adding classes to html snippets class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') -@register.tag(name='code') +@register.tag(name="code") def highlight_code(parser, token): code = token.split_contents()[-1] - nodelist = parser.parse(('endcode',)) + nodelist = parser.parse(("endcode",)) parser.delete_first_token() return CodeNode(code, nodelist) class CodeNode(template.Node): - style = 'emacs' + style = "emacs" def __init__(self, lang, code): self.lang = lang @@ -43,24 +44,17 @@ class CodeNode(template.Node): @register.filter() def with_location(fields, location): - return [ - field for field in fields - if field.location == location - ] + return [field for field in fields if field.location == location] @register.simple_tag def form_for_link(link): import coreschema - properties = OrderedDict([ - (field.name, field.schema or coreschema.String()) - for field in link.fields - ]) - required = [ - field.name - for field in link.fields - if field.required - ] + + properties = OrderedDict( + [(field.name, field.schema or coreschema.String()) for field in link.fields] + ) + required = [field.name for field in link.fields if field.required] schema = coreschema.Object(properties=properties, required=required) return mark_safe(coreschema.render_to_form(schema)) @@ -79,14 +73,14 @@ def get_pagination_html(pager): @register.simple_tag def render_form(serializer, template_pack=None): - style = {'template_pack': template_pack} if template_pack else {} + style = {"template_pack": template_pack} if template_pack else {} renderer = HTMLFormRenderer() - return renderer.render(serializer.data, None, {'style': style}) + return renderer.render(serializer.data, None, {"style": style}) @register.simple_tag def render_field(field, style): - renderer = style.get('renderer', HTMLFormRenderer()) + renderer = style.get("renderer", HTMLFormRenderer()) return renderer.render_field(field, style) @@ -96,9 +90,9 @@ def optional_login(request): Include a login snippet if REST framework's login view is in the URLconf. """ try: - login_url = reverse('rest_framework:login') + login_url = reverse("rest_framework:login") except NoReverseMatch: - return '' + return "" snippet = "
  • Log in
  • " snippet = format_html(snippet, href=login_url, next=escape(request.path)) @@ -112,9 +106,9 @@ def optional_docs_login(request): Include a login snippet if REST framework's login view is in the URLconf. """ try: - login_url = reverse('rest_framework:login') + login_url = reverse("rest_framework:login") except NoReverseMatch: - return 'log in' + return "log in" snippet = "log in" snippet = format_html(snippet, href=login_url, next=escape(request.path)) @@ -128,7 +122,7 @@ def optional_logout(request, user): Include a logout snippet if REST framework's logout view is in the URLconf. """ try: - logout_url = reverse('rest_framework:logout') + logout_url = reverse("rest_framework:logout") except NoReverseMatch: snippet = format_html('', user=escape(user)) return mark_safe(snippet) @@ -142,7 +136,9 @@ def optional_logout(request, user):
  • Log out
  • """ - snippet = format_html(snippet, user=escape(user), href=logout_url, next=escape(request.path)) + snippet = format_html( + snippet, user=escape(user), href=logout_url, next=escape(request.path) + ) return mark_safe(snippet) @@ -160,16 +156,13 @@ def add_query_param(request, key, val): @register.filter def as_string(value): if value is None: - return '' - return '%s' % value + return "" + return "%s" % value @register.filter def as_list_of_strings(value): - return [ - '' if (item is None) else ('%s' % item) - for item in value - ] + return ["" if (item is None) else ("%s" % item) for item in value] @register.filter @@ -190,45 +183,52 @@ def add_class(value, css_class): html = six.text_type(value) match = class_re.search(html) if match: - m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class, - css_class, css_class), - match.group(1)) + m = re.search( + r"^%s$|^%s\s|\s%s\s|\s%s$" % (css_class, css_class, css_class, css_class), + match.group(1), + ) if not m: - return mark_safe(class_re.sub(match.group(1) + " " + css_class, - html)) + return mark_safe(class_re.sub(match.group(1) + " " + css_class, html)) else: - return mark_safe(html.replace('>', ' class="%s">' % css_class, 1)) + return mark_safe(html.replace(">", ' class="%s">' % css_class, 1)) return value @register.filter def format_value(value): - if getattr(value, 'is_hyperlink', False): + if getattr(value, "is_hyperlink", False): name = six.text_type(value.obj) - return mark_safe('%s' % (value, escape(name))) + return mark_safe("%s" % (value, escape(name))) if value is None or isinstance(value, bool): - return mark_safe('%s' % {True: 'true', False: 'false', None: 'null'}[value]) + return mark_safe( + "%s" % {True: "true", False: "false", None: "null"}[value] + ) elif isinstance(value, list): if any([isinstance(item, (list, dict)) for item in value]): - template = loader.get_template('rest_framework/admin/list_value.html') + template = loader.get_template("rest_framework/admin/list_value.html") else: - template = loader.get_template('rest_framework/admin/simple_list_value.html') - context = {'value': value} + template = loader.get_template( + "rest_framework/admin/simple_list_value.html" + ) + context = {"value": value} return template.render(context) elif isinstance(value, dict): - template = loader.get_template('rest_framework/admin/dict_value.html') - context = {'value': value} + template = loader.get_template("rest_framework/admin/dict_value.html") + context = {"value": value} return template.render(context) elif isinstance(value, six.string_types): - if ( - (value.startswith('http:') or value.startswith('https:')) and not - re.search(r'\s', value) + if (value.startswith("http:") or value.startswith("https:")) and not re.search( + r"\s", value ): - return mark_safe('{value}'.format(value=escape(value))) - elif '@' in value and not re.search(r'\s', value): - return mark_safe('{value}'.format(value=escape(value))) - elif '\n' in value: - return mark_safe('
    %s
    ' % escape(value)) + return mark_safe( + '{value}'.format(value=escape(value)) + ) + elif "@" in value and not re.search(r"\s", value): + return mark_safe( + '{value}'.format(value=escape(value)) + ) + elif "\n" in value: + return mark_safe("
    %s
    " % escape(value)) return six.text_type(value) @@ -266,7 +266,7 @@ def schema_links(section, sec_key=None): """ Recursively find every link in a schema, even nested. """ - NESTED_FORMAT = '%s > %s' # this format is used in docs/js/api.js:normalizeKeys + NESTED_FORMAT = "%s > %s" # this format is used in docs/js/api.js:normalizeKeys links = section.links if section.data: data = section.data.items() @@ -287,20 +287,30 @@ def schema_links(section, sec_key=None): @register.filter def add_nested_class(value): if isinstance(value, dict): - return 'class=nested' - if isinstance(value, list) and any([isinstance(item, (list, dict)) for item in value]): - return 'class=nested' - return '' + return "class=nested" + if isinstance(value, list) and any( + [isinstance(item, (list, dict)) for item in value] + ): + return "class=nested" + return "" # Bunch of stuff cloned from urlize -TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"] -WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), - ('"', '"'), ("'", "'")] -word_split_re = re.compile(r'(\s+)') -simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE) -simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) -simple_email_re = re.compile(r'^\S+@\S+\.\S+$') +TRAILING_PUNCTUATION = [".", ",", ":", ";", ".)", '"', "']", "'}", "'"] +WRAPPING_PUNCTUATION = [ + ("(", ")"), + ("<", ">"), + ("[", "]"), + ("<", ">"), + ('"', '"'), + ("'", "'"), +] +word_split_re = re.compile(r"(\s+)") +simple_url_re = re.compile(r"^https?://\[?\w", re.IGNORECASE) +simple_url_2_re = re.compile( + r"^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$", re.IGNORECASE +) +simple_email_re = re.compile(r"^\S+@\S+\.\S+$") def smart_urlquote_wrapper(matched_url): @@ -332,8 +342,13 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru If autoescape is True, the link text and URLs will get autoescaped. """ + def trim_url(x, limit=trim_url_limit): - return limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x + return ( + limit is not None + and (len(x) > limit and ("%s..." % x[: max(0, limit - 3)])) + or x + ) safe_input = isinstance(text, SafeData) @@ -344,40 +359,40 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru words = word_split_re.split(force_text(text)) for i, word in enumerate(words): - if '.' in word or '@' in word or ':' in word: + if "." in word or "@" in word or ":" in word: # Deal with punctuation. - lead, middle, trail = '', word, '' + lead, middle, trail = "", word, "" for punctuation in TRAILING_PUNCTUATION: if middle.endswith(punctuation): - middle = middle[:-len(punctuation)] + middle = middle[: -len(punctuation)] trail = punctuation + trail for opening, closing in WRAPPING_PUNCTUATION: if middle.startswith(opening): - middle = middle[len(opening):] + middle = middle[len(opening) :] lead = lead + opening # Keep parentheses at the end only if they're balanced. if ( - middle.endswith(closing) and - middle.count(closing) == middle.count(opening) + 1 + middle.endswith(closing) + and middle.count(closing) == middle.count(opening) + 1 ): - middle = middle[:-len(closing)] + middle = middle[: -len(closing)] trail = closing + trail # Make URL we want to point to. url = None - nofollow_attr = ' rel="nofollow"' if nofollow else '' + nofollow_attr = ' rel="nofollow"' if nofollow else "" if simple_url_re.match(middle): url = smart_urlquote_wrapper(middle) elif simple_url_2_re.match(middle): - url = smart_urlquote_wrapper('http://%s' % middle) - elif ':' not in middle and simple_email_re.match(middle): - local, domain = middle.rsplit('@', 1) + url = smart_urlquote_wrapper("http://%s" % middle) + elif ":" not in middle and simple_email_re.match(middle): + local, domain = middle.rsplit("@", 1) try: - domain = domain.encode('idna').decode('ascii') + domain = domain.encode("idna").decode("ascii") except UnicodeError: continue - url = 'mailto:%s@%s' % (local, domain) - nofollow_attr = '' + url = "mailto:%s@%s" % (local, domain) + nofollow_attr = "" # Make link. if url: @@ -385,12 +400,12 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru lead, trail = conditional_escape(lead), conditional_escape(trail) url, trimmed = conditional_escape(url), conditional_escape(trimmed) middle = '%s' % (url, nofollow_attr, trimmed) - words[i] = '%s%s%s' % (lead, middle, trail) + words[i] = "%s%s%s" % (lead, middle, trail) else: words[i] = conditional_escape(word) else: words[i] = conditional_escape(word) - return mark_safe(''.join(words)) + return mark_safe("".join(words)) @register.filter @@ -399,6 +414,6 @@ def break_long_headers(header): Breaks headers longer than 160 characters (~page length) when possible (are comma separated) """ - if len(header) > 160 and ',' in header: - header = mark_safe('
    ' + ',
    '.join(header.split(','))) + if len(header) > 160 and "," in header: + header = mark_safe("
    " + ",
    ".join(header.split(","))) return header diff --git a/rest_framework/test.py b/rest_framework/test.py index edacf0066..0dd6944ec 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -11,9 +11,11 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.handlers.wsgi import WSGIHandler from django.test import override_settings, testcases -from django.test.client import Client as DjangoClient -from django.test.client import ClientHandler -from django.test.client import RequestFactory as DjangoRequestFactory +from django.test.client import ( + Client as DjangoClient, + ClientHandler, + RequestFactory as DjangoRequestFactory, +) from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode @@ -28,6 +30,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: + class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): def get_all(self, key, default): return self.getheaders(key) @@ -48,6 +51,7 @@ if requests is not None: A transport adapter for `requests`, that makes requests via the Django WSGI app, rather than making actual HTTP requests over the network. """ + def __init__(self): self.app = WSGIHandler() self.factory = DjangoRequestFactory() @@ -62,19 +66,19 @@ if requests is not None: # Set request content, if any exists. if request.body is not None: - if hasattr(request.body, 'read'): - kwargs['data'] = request.body.read() + if hasattr(request.body, "read"): + kwargs["data"] = request.body.read() else: - kwargs['data'] = request.body - if 'content-type' in request.headers: - kwargs['content_type'] = request.headers['content-type'] + kwargs["data"] = request.body + if "content-type" in request.headers: + kwargs["content_type"] = request.headers["content-type"] # Set request headers. for key, value in request.headers.items(): key = key.upper() - if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): + if key in ("CONNECTION", "CONTENT-LENGTH", "CONTENT-TYPE"): continue - kwargs['HTTP_%s' % key.replace('-', '_')] = value + kwargs["HTTP_%s" % key.replace("-", "_")] = value return self.factory.generic(method, url, **kwargs).environ @@ -85,20 +89,20 @@ if requests is not None: raw_kwargs = {} def start_response(wsgi_status, wsgi_headers): - status, _, reason = wsgi_status.partition(' ') - raw_kwargs['status'] = int(status) - raw_kwargs['reason'] = reason - raw_kwargs['headers'] = wsgi_headers - raw_kwargs['version'] = 11 - raw_kwargs['preload_content'] = False - raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) + status, _, reason = wsgi_status.partition(" ") + raw_kwargs["status"] = int(status) + raw_kwargs["reason"] = reason + raw_kwargs["headers"] = wsgi_headers + raw_kwargs["version"] = 11 + raw_kwargs["preload_content"] = False + raw_kwargs["original_response"] = MockOriginalResponse(wsgi_headers) # Make the outgoing request via WSGI. environ = self.get_environ(request) wsgi_response = self.app(environ, start_response) # Build the underlying urllib3.HTTPResponse - raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) + raw_kwargs["body"] = io.BytesIO(b"".join(wsgi_response)) raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) # Build the requests.Response @@ -111,33 +115,47 @@ if requests is not None: def __init__(self, *args, **kwargs): super(RequestsClient, self).__init__(*args, **kwargs) adapter = DjangoTestAdapter() - self.mount('http://', adapter) - self.mount('https://', adapter) + self.mount("http://", adapter) + self.mount("https://", adapter) def request(self, method, url, *args, **kwargs): - if not url.startswith('http'): - raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url) + if not url.startswith("http"): + raise ValueError( + 'Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' + % url + ) return super(RequestsClient, self).request(method, url, *args, **kwargs) + else: + def RequestsClient(*args, **kwargs): - raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.') + raise ImproperlyConfigured( + "requests must be installed in order to use RequestsClient." + ) if coreapi is not None: + class CoreAPIClient(coreapi.Client): def __init__(self, *args, **kwargs): self._session = RequestsClient() - kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)] + kwargs["transports"] = [ + coreapi.transports.HTTPTransport(session=self.session) + ] return super(CoreAPIClient, self).__init__(*args, **kwargs) @property def session(self): return self._session + else: + def CoreAPIClient(*args, **kwargs): - raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.') + raise ImproperlyConfigured( + "coreapi must be installed in order to use CoreAPIClient." + ) class APIRequestFactory(DjangoRequestFactory): @@ -157,11 +175,11 @@ class APIRequestFactory(DjangoRequestFactory): """ if data is None: - return ('', content_type) + return ("", content_type) - assert format is None or content_type is None, ( - 'You may not set both `format` and `content_type`.' - ) + assert ( + format is None or content_type is None + ), "You may not set both `format` and `content_type`." if content_type: # Content type specified explicitly, treat data as a raw bytestring @@ -175,7 +193,7 @@ class APIRequestFactory(DjangoRequestFactory): "Set TEST_REQUEST_RENDERER_CLASSES to enable " "extra request formats.".format( format, - ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes]) + ", ".join(["'" + fmt + "'" for fmt in self.renderer_classes]), ) ) @@ -195,47 +213,53 @@ class APIRequestFactory(DjangoRequestFactory): return ret, content_type def get(self, path, data=None, **extra): - r = { - 'QUERY_STRING': urlencode(data or {}, doseq=True), - } - if not data and '?' in path: + r = {"QUERY_STRING": urlencode(data or {}, doseq=True)} + if not data and "?" in path: # Fix to support old behavior where you have the arguments in the # url. See #1461. - query_string = force_bytes(path.split('?')[1]) + query_string = force_bytes(path.split("?")[1]) if six.PY3: - query_string = query_string.decode('iso-8859-1') - r['QUERY_STRING'] = query_string + query_string = query_string.decode("iso-8859-1") + r["QUERY_STRING"] = query_string r.update(extra) - return self.generic('GET', path, **r) + return self.generic("GET", path, **r) def post(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) - return self.generic('POST', path, data, content_type, **extra) + return self.generic("POST", path, data, content_type, **extra) def put(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) - return self.generic('PUT', path, data, content_type, **extra) + return self.generic("PUT", path, data, content_type, **extra) def patch(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) - return self.generic('PATCH', path, data, content_type, **extra) + return self.generic("PATCH", path, data, content_type, **extra) def delete(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) - return self.generic('DELETE', path, data, content_type, **extra) + return self.generic("DELETE", path, data, content_type, **extra) def options(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) - return self.generic('OPTIONS', path, data, content_type, **extra) + return self.generic("OPTIONS", path, data, content_type, **extra) - def generic(self, method, path, data='', - content_type='application/octet-stream', secure=False, **extra): + def generic( + self, + method, + path, + data="", + content_type="application/octet-stream", + secure=False, + **extra + ): # Include the CONTENT_TYPE, regardless of whether or not data is empty. if content_type is not None: - extra['CONTENT_TYPE'] = str(content_type) + extra["CONTENT_TYPE"] = str(content_type) return super(APIRequestFactory, self).generic( - method, path, data, content_type, secure, **extra) + method, path, data, content_type, secure, **extra + ) def request(self, **kwargs): request = super(APIRequestFactory, self).request(**kwargs) @@ -294,42 +318,52 @@ class APIClient(APIRequestFactory, DjangoClient): response = self._handle_redirects(response, **extra) return response - def post(self, path, data=None, format=None, content_type=None, - follow=False, **extra): + def post( + self, path, data=None, format=None, content_type=None, follow=False, **extra + ): response = super(APIClient, self).post( - path, data=data, format=format, content_type=content_type, **extra) + path, data=data, format=format, content_type=content_type, **extra + ) if follow: response = self._handle_redirects(response, **extra) return response - def put(self, path, data=None, format=None, content_type=None, - follow=False, **extra): + def put( + self, path, data=None, format=None, content_type=None, follow=False, **extra + ): response = super(APIClient, self).put( - path, data=data, format=format, content_type=content_type, **extra) + path, data=data, format=format, content_type=content_type, **extra + ) if follow: response = self._handle_redirects(response, **extra) return response - def patch(self, path, data=None, format=None, content_type=None, - follow=False, **extra): + def patch( + self, path, data=None, format=None, content_type=None, follow=False, **extra + ): response = super(APIClient, self).patch( - path, data=data, format=format, content_type=content_type, **extra) + path, data=data, format=format, content_type=content_type, **extra + ) if follow: response = self._handle_redirects(response, **extra) return response - def delete(self, path, data=None, format=None, content_type=None, - follow=False, **extra): + def delete( + self, path, data=None, format=None, content_type=None, follow=False, **extra + ): response = super(APIClient, self).delete( - path, data=data, format=format, content_type=content_type, **extra) + path, data=data, format=format, content_type=content_type, **extra + ) if follow: response = self._handle_redirects(response, **extra) return response - def options(self, path, data=None, format=None, content_type=None, - follow=False, **extra): + def options( + self, path, data=None, format=None, content_type=None, follow=False, **extra + ): response = super(APIClient, self).options( - path, data=data, format=format, content_type=content_type, **extra) + path, data=data, format=format, content_type=content_type, **extra + ) if follow: response = self._handle_redirects(response, **extra) return response @@ -377,13 +411,14 @@ class URLPatternsTestCase(testcases.SimpleTestCase): def test_something_else(self): ... """ + @classmethod def setUpClass(cls): # Get the module of the TestCase subclass cls._module = import_module(cls.__module__) cls._override = override_settings(ROOT_URLCONF=cls.__module__) - if hasattr(cls._module, 'urlpatterns'): + if hasattr(cls._module, "urlpatterns"): cls._module_urlpatterns = cls._module.urlpatterns cls._module.urlpatterns = cls.urlpatterns @@ -396,7 +431,7 @@ class URLPatternsTestCase(testcases.SimpleTestCase): super(URLPatternsTestCase, cls).tearDownClass() cls._override.disable() - if hasattr(cls, '_module_urlpatterns'): + if hasattr(cls, "_module_urlpatterns"): cls._module.urlpatterns = cls._module_urlpatterns else: del cls._module.urlpatterns diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 834ced148..0f7f741fd 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -20,7 +20,7 @@ class BaseThrottle(object): """ Return `True` if the request should be allowed, `False` otherwise. """ - raise NotImplementedError('.allow_request() must be overridden') + raise NotImplementedError(".allow_request() must be overridden") def get_ident(self, request): """ @@ -28,18 +28,18 @@ class BaseThrottle(object): if present and number of proxies is > 0. If not use all of HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR. """ - xff = request.META.get('HTTP_X_FORWARDED_FOR') - remote_addr = request.META.get('REMOTE_ADDR') + xff = request.META.get("HTTP_X_FORWARDED_FOR") + remote_addr = request.META.get("REMOTE_ADDR") num_proxies = api_settings.NUM_PROXIES if num_proxies is not None: if num_proxies == 0 or xff is None: return remote_addr - addrs = xff.split(',') + addrs = xff.split(",") client_addr = addrs[-min(num_proxies, len(addrs))] return client_addr.strip() - return ''.join(xff.split()) if xff else remote_addr + return "".join(xff.split()) if xff else remote_addr def wait(self): """ @@ -61,14 +61,15 @@ class SimpleRateThrottle(BaseThrottle): Previous request information used for throttling is stored in the cache. """ + cache = default_cache timer = time.time - cache_format = 'throttle_%(scope)s_%(ident)s' + cache_format = "throttle_%(scope)s_%(ident)s" scope = None THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES def __init__(self): - if not getattr(self, 'rate', None): + if not getattr(self, "rate", None): self.rate = self.get_rate() self.num_requests, self.duration = self.parse_rate(self.rate) @@ -79,15 +80,17 @@ class SimpleRateThrottle(BaseThrottle): May return `None` if the request should not be throttled. """ - raise NotImplementedError('.get_cache_key() must be overridden') + raise NotImplementedError(".get_cache_key() must be overridden") def get_rate(self): """ Determine the string representation of the allowed request rate. """ - if not getattr(self, 'scope', None): - msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % - self.__class__.__name__) + if not getattr(self, "scope", None): + msg = ( + "You must set either `.scope` or `.rate` for '%s' throttle" + % self.__class__.__name__ + ) raise ImproperlyConfigured(msg) try: @@ -103,9 +106,9 @@ class SimpleRateThrottle(BaseThrottle): """ if rate is None: return (None, None) - num, period = rate.split('/') + num, period = rate.split("/") num_requests = int(num) - duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]] return (num_requests, duration) def allow_request(self, request, view): @@ -170,15 +173,16 @@ class AnonRateThrottle(SimpleRateThrottle): The IP address of the request will be used as the unique cache key. """ - scope = 'anon' + + scope = "anon" def get_cache_key(self, request, view): if request.user.is_authenticated: return None # Only throttle unauthenticated requests. return self.cache_format % { - 'scope': self.scope, - 'ident': self.get_ident(request) + "scope": self.scope, + "ident": self.get_ident(request), } @@ -190,7 +194,8 @@ class UserRateThrottle(SimpleRateThrottle): authenticated. For anonymous requests, the IP address of the request will be used. """ - scope = 'user' + + scope = "user" def get_cache_key(self, request, view): if request.user.is_authenticated: @@ -198,10 +203,7 @@ class UserRateThrottle(SimpleRateThrottle): else: ident = self.get_ident(request) - return self.cache_format % { - 'scope': self.scope, - 'ident': ident - } + return self.cache_format % {"scope": self.scope, "ident": ident} class ScopedRateThrottle(SimpleRateThrottle): @@ -211,7 +213,8 @@ class ScopedRateThrottle(SimpleRateThrottle): throttled. The unique cache key will be generated by concatenating the user id of the request, and the scope of the view being accessed. """ - scope_attr = 'throttle_scope' + + scope_attr = "throttle_scope" def __init__(self): # Override the usual SimpleRateThrottle, because we can't determine @@ -246,7 +249,4 @@ class ScopedRateThrottle(SimpleRateThrottle): else: ident = self.get_ident(request) - return self.cache_format % { - 'scope': self.scope, - 'ident': ident - } + return self.cache_format % {"scope": self.scope, "ident": ident} diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index ab3a74978..076f03b38 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -3,7 +3,11 @@ from __future__ import unicode_literals from django.conf.urls import include, url from rest_framework.compat import ( - URLResolver, get_regex_pattern, is_route_pattern, path, register_converter + URLResolver, + get_regex_pattern, + is_route_pattern, + path, + register_converter, ) from rest_framework.settings import api_settings @@ -13,7 +17,7 @@ def _get_format_path_converter(suffix_kwarg, allowed): if len(allowed) == 1: allowed_pattern = allowed[0] else: - allowed_pattern = '(?:%s)' % '|'.join(allowed) + allowed_pattern = "(?:%s)" % "|".join(allowed) suffix_pattern = r"\.%s/?" % allowed_pattern else: suffix_pattern = r"\.[a-z0-9]+/?" @@ -22,19 +26,21 @@ def _get_format_path_converter(suffix_kwarg, allowed): regex = suffix_pattern def to_python(self, value): - return value.strip('./') + return value.strip("./") def to_url(self, value): - return '.' + value + '/' + return "." + value + "/" - converter_name = 'drf_format_suffix' + converter_name = "drf_format_suffix" if allowed: - converter_name += '_' + '_'.join(allowed) + converter_name += "_" + "_".join(allowed) return converter_name, FormatSuffixConverter -def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route=None): +def apply_suffix_patterns( + urlpatterns, suffix_pattern, suffix_required, suffix_route=None +): ret = [] for urlpattern in urlpatterns: if isinstance(urlpattern, URLResolver): @@ -44,23 +50,28 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r app_name = urlpattern.app_name kwargs = urlpattern.default_kwargs # Add in the included patterns, after applying the suffixes - patterns = apply_suffix_patterns(urlpattern.url_patterns, - suffix_pattern, - suffix_required, - suffix_route) + patterns = apply_suffix_patterns( + urlpattern.url_patterns, suffix_pattern, suffix_required, suffix_route + ) # if the original pattern was a RoutePattern we need to preserve it if is_route_pattern(urlpattern): assert path is not None route = str(urlpattern.pattern) - new_pattern = path(route, include((patterns, app_name), namespace), kwargs) + new_pattern = path( + route, include((patterns, app_name), namespace), kwargs + ) else: - new_pattern = url(regex, include((patterns, app_name), namespace), kwargs) + new_pattern = url( + regex, include((patterns, app_name), namespace), kwargs + ) ret.append(new_pattern) else: # Regular URL pattern - regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern + regex = ( + get_regex_pattern(urlpattern).rstrip("$").rstrip("/") + suffix_pattern + ) view = urlpattern.callback kwargs = urlpattern.default_args name = urlpattern.name @@ -72,7 +83,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r if is_route_pattern(urlpattern): assert path is not None assert suffix_route is not None - route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route + route = str(urlpattern.pattern).rstrip("$").rstrip("/") + suffix_route new_pattern = path(route, view, kwargs, name) else: new_pattern = url(regex, view, kwargs, name) @@ -103,17 +114,21 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None): if len(allowed) == 1: allowed_pattern = allowed[0] else: - allowed_pattern = '(%s)' % '|'.join(allowed) - suffix_pattern = r'\.(?P<%s>%s)/?$' % (suffix_kwarg, allowed_pattern) + allowed_pattern = "(%s)" % "|".join(allowed) + suffix_pattern = r"\.(?P<%s>%s)/?$" % (suffix_kwarg, allowed_pattern) else: - suffix_pattern = r'\.(?P<%s>[a-z0-9]+)/?$' % suffix_kwarg + suffix_pattern = r"\.(?P<%s>[a-z0-9]+)/?$" % suffix_kwarg if path and register_converter: - converter_name, suffix_converter = _get_format_path_converter(suffix_kwarg, allowed) + converter_name, suffix_converter = _get_format_path_converter( + suffix_kwarg, allowed + ) register_converter(suffix_converter, converter_name) - suffix_route = '<%s:%s>' % (converter_name, suffix_kwarg) + suffix_route = "<%s:%s>" % (converter_name, suffix_kwarg) else: suffix_route = None - return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route) + return apply_suffix_patterns( + urlpatterns, suffix_pattern, suffix_required, suffix_route + ) diff --git a/rest_framework/urls.py b/rest_framework/urls.py index 0e4c2661b..6932e5e27 100644 --- a/rest_framework/urls.py +++ b/rest_framework/urls.py @@ -16,8 +16,13 @@ from __future__ import unicode_literals from django.conf.urls import url from django.contrib.auth import views -app_name = 'rest_framework' + +app_name = "rest_framework" urlpatterns = [ - url(r'^login/$', views.LoginView.as_view(template_name='rest_framework/login.html'), name='login'), - url(r'^logout/$', views.LogoutView.as_view(), name='logout'), + url( + r"^login/$", + views.LoginView.as_view(template_name="rest_framework/login.html"), + name="login", + ), + url(r"^logout/$", views.LogoutView.as_view(), name="logout"), ] diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index e0374ffd0..b3626a52e 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -23,8 +23,8 @@ def get_breadcrumbs(url, request=None): else: # Check if this is a REST framework view, # and if so add it to the breadcrumbs - cls = getattr(view, 'cls', None) - initkwargs = getattr(view, 'initkwargs', {}) + cls = getattr(view, "cls", None) + initkwargs = getattr(view, "initkwargs", {}) if cls is not None and issubclass(cls, APIView): # Don't list the same view twice in a row. # Probably an optional trailing slash. @@ -35,21 +35,21 @@ def get_breadcrumbs(url, request=None): breadcrumbs_list.insert(0, (name, insert_url)) seen.append(view) - if url == '': + if url == "": # All done return breadcrumbs_list - elif url.endswith('/'): + elif url.endswith("/"): # Drop trailing slash off the end and continue to try to # resolve more breadcrumbs - url = url.rstrip('/') + url = url.rstrip("/") return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) # Drop trailing non-slash off the end and continue to try to # resolve more breadcrumbs - url = url[:url.rfind('/') + 1] + url = url[: url.rfind("/") + 1] return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen) - prefix = get_script_prefix().rstrip('/') - url = url[len(prefix):] + prefix = get_script_prefix().rstrip("/") + url = url[len(prefix) :] return breadcrumbs_recursive(url, [], prefix, []) diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py index d8f4aeb4e..66ecfb133 100644 --- a/rest_framework/utils/encoders.py +++ b/rest_framework/utils/encoders.py @@ -21,6 +21,7 @@ class JSONEncoder(json.JSONEncoder): JSONEncoder subclass that knows how to encode date/time/timedelta, decimal types, generators and other basic python objects. """ + def default(self, obj): # For Date Time string spec, see ECMA 262 # https://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15 @@ -28,8 +29,8 @@ class JSONEncoder(json.JSONEncoder): return force_text(obj) elif isinstance(obj, datetime.datetime): representation = obj.isoformat() - if representation.endswith('+00:00'): - representation = representation[:-6] + 'Z' + if representation.endswith("+00:00"): + representation = representation[:-6] + "Z" return representation elif isinstance(obj, datetime.date): return obj.isoformat() @@ -49,20 +50,22 @@ class JSONEncoder(json.JSONEncoder): return tuple(obj) elif isinstance(obj, bytes): # Best-effort for binary blobs. See #4187. - return obj.decode('utf-8') - elif hasattr(obj, 'tolist'): + return obj.decode("utf-8") + elif hasattr(obj, "tolist"): # Numpy arrays and array scalars. return obj.tolist() - elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)): + elif (coreapi is not None) and isinstance( + obj, (coreapi.Document, coreapi.Error) + ): raise RuntimeError( - 'Cannot return a coreapi object from a JSON view. ' - 'You should be using a schema renderer instead for this view.' + "Cannot return a coreapi object from a JSON view. " + "You should be using a schema renderer instead for this view." ) - elif hasattr(obj, '__getitem__'): + elif hasattr(obj, "__getitem__"): try: return dict(obj) except Exception: pass - elif hasattr(obj, '__iter__'): + elif hasattr(obj, "__iter__"): return tuple(item for item in obj) return super(JSONEncoder, self).default(obj) diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 927d08ff2..55a1fd603 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -11,8 +11,12 @@ from django.utils.text import capfirst from rest_framework.compat import postgres_fields from rest_framework.validators import UniqueValidator + NUMERIC_FIELD_TYPES = ( - models.IntegerField, models.FloatField, models.DecimalField, models.DurationField, + models.IntegerField, + models.FloatField, + models.DecimalField, + models.DurationField, ) @@ -23,11 +27,12 @@ class ClassLookupDict(object): hierarchy in method resolution order, and returns the first matching value from the dictionary or raises a KeyError if nothing matches. """ + def __init__(self, mapping): self.mapping = mapping def __getitem__(self, key): - if hasattr(key, '_proxy_class'): + if hasattr(key, "_proxy_class"): # Deal with proxy classes. Ie. BoundField behaves as if it # is a Field instance when using ClassLookupDict. base_class = key._proxy_class @@ -37,7 +42,7 @@ class ClassLookupDict(object): for cls in inspect.getmro(base_class): if cls in self.mapping: return self.mapping[cls] - raise KeyError('Class %s not found in lookup.' % base_class.__name__) + raise KeyError("Class %s not found in lookup." % base_class.__name__) def __setitem__(self, key, value): self.mapping[key] = value @@ -48,7 +53,7 @@ def needs_label(model_field, field_name): Returns `True` if the label based on the model's verbose name is not equal to the default label it would have based on it's field name. """ - default_label = field_name.replace('_', ' ').capitalize() + default_label = field_name.replace("_", " ").capitalize() return capfirst(model_field.verbose_name) != default_label @@ -57,9 +62,9 @@ def get_detail_view_name(model): Given a model class, return the view name to use for URL relationships that refer to instances of the model. """ - return '%(model_name)s-detail' % { - 'app_label': model._meta.app_label, - 'model_name': model._meta.object_name.lower() + return "%(model_name)s-detail" % { + "app_label": model._meta.app_label, + "model_name": model._meta.object_name.lower(), } @@ -72,84 +77,98 @@ def get_field_kwargs(field_name, model_field): # The following will only be used by ModelField classes. # Gets removed for everything else. - kwargs['model_field'] = model_field + kwargs["model_field"] = model_field if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) + kwargs["label"] = capfirst(model_field.verbose_name) if model_field.help_text: - kwargs['help_text'] = model_field.help_text + kwargs["help_text"] = model_field.help_text - max_digits = getattr(model_field, 'max_digits', None) + max_digits = getattr(model_field, "max_digits", None) if max_digits is not None: - kwargs['max_digits'] = max_digits + kwargs["max_digits"] = max_digits - decimal_places = getattr(model_field, 'decimal_places', None) + decimal_places = getattr(model_field, "decimal_places", None) if decimal_places is not None: - kwargs['decimal_places'] = decimal_places + kwargs["decimal_places"] = decimal_places if isinstance(model_field, models.SlugField): - kwargs['allow_unicode'] = model_field.allow_unicode + kwargs["allow_unicode"] = model_field.allow_unicode - if isinstance(model_field, models.TextField) or (postgres_fields and isinstance(model_field, postgres_fields.JSONField)): - kwargs['style'] = {'base_template': 'textarea.html'} + if isinstance(model_field, models.TextField) or ( + postgres_fields and isinstance(model_field, postgres_fields.JSONField) + ): + kwargs["style"] = {"base_template": "textarea.html"} if isinstance(model_field, models.AutoField) or not model_field.editable: # If this field is read-only, then return early. # Further keyword arguments are not valid. - kwargs['read_only'] = True + kwargs["read_only"] = True return kwargs if model_field.has_default() or model_field.blank or model_field.null: - kwargs['required'] = False + kwargs["required"] = False if model_field.null and not isinstance(model_field, models.NullBooleanField): - kwargs['allow_null'] = True + kwargs["allow_null"] = True - if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))): - kwargs['allow_blank'] = True + if model_field.blank and ( + isinstance(model_field, (models.CharField, models.TextField)) + ): + kwargs["allow_blank"] = True if isinstance(model_field, models.FilePathField): - kwargs['path'] = model_field.path + kwargs["path"] = model_field.path if model_field.match is not None: - kwargs['match'] = model_field.match + kwargs["match"] = model_field.match if model_field.recursive is not False: - kwargs['recursive'] = model_field.recursive + kwargs["recursive"] = model_field.recursive if model_field.allow_files is not True: - kwargs['allow_files'] = model_field.allow_files + kwargs["allow_files"] = model_field.allow_files if model_field.allow_folders is not False: - kwargs['allow_folders'] = model_field.allow_folders + kwargs["allow_folders"] = model_field.allow_folders if model_field.choices: - kwargs['choices'] = model_field.choices + kwargs["choices"] = model_field.choices else: # Ensure that max_value is passed explicitly as a keyword arg, # rather than as a validator. - max_value = next(( - validator.limit_value for validator in validator_kwarg - if isinstance(validator, validators.MaxValueValidator) - ), None) + max_value = next( + ( + validator.limit_value + for validator in validator_kwarg + if isinstance(validator, validators.MaxValueValidator) + ), + None, + ) if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): - kwargs['max_value'] = max_value + kwargs["max_value"] = max_value validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if not isinstance(validator, validators.MaxValueValidator) ] # Ensure that min_value is passed explicitly as a keyword arg, # rather than as a validator. - min_value = next(( - validator.limit_value for validator in validator_kwarg - if isinstance(validator, validators.MinValueValidator) - ), None) + min_value = next( + ( + validator.limit_value + for validator in validator_kwarg + if isinstance(validator, validators.MinValueValidator) + ), + None, + ) if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES): - kwargs['min_value'] = min_value + kwargs["min_value"] = min_value validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if not isinstance(validator, validators.MinValueValidator) ] @@ -157,7 +176,8 @@ def get_field_kwargs(field_name, model_field): # as it is explicitly added in. if isinstance(model_field, models.URLField): validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if not isinstance(validator, validators.URLValidator) ] @@ -165,67 +185,79 @@ def get_field_kwargs(field_name, model_field): # as it is explicitly added in. if isinstance(model_field, models.EmailField): validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if validator is not validators.validate_email ] # SlugField do not need to include the 'validate_slug' argument, if isinstance(model_field, models.SlugField): validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if validator is not validators.validate_slug ] # IPAddressField do not need to include the 'validate_ipv46_address' argument, if isinstance(model_field, models.GenericIPAddressField): validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if validator is not validators.validate_ipv46_address ] # Our decimal validation is handled in the field code, not validator code. if isinstance(model_field, models.DecimalField): validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if not isinstance(validator, validators.DecimalValidator) ] # Ensure that max_length is passed explicitly as a keyword arg, # rather than as a validator. - max_length = getattr(model_field, 'max_length', None) - if max_length is not None and (isinstance(model_field, (models.CharField, models.TextField, models.FileField))): - kwargs['max_length'] = max_length + max_length = getattr(model_field, "max_length", None) + if max_length is not None and ( + isinstance(model_field, (models.CharField, models.TextField, models.FileField)) + ): + kwargs["max_length"] = max_length validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if not isinstance(validator, validators.MaxLengthValidator) ] # Ensure that min_length is passed explicitly as a keyword arg, # rather than as a validator. - min_length = next(( - validator.limit_value for validator in validator_kwarg - if isinstance(validator, validators.MinLengthValidator) - ), None) + min_length = next( + ( + validator.limit_value + for validator in validator_kwarg + if isinstance(validator, validators.MinLengthValidator) + ), + None, + ) if min_length is not None and isinstance(model_field, models.CharField): - kwargs['min_length'] = min_length + kwargs["min_length"] = min_length validator_kwarg = [ - validator for validator in validator_kwarg + validator + for validator in validator_kwarg if not isinstance(validator, validators.MinLengthValidator) ] - if getattr(model_field, 'unique', False): - unique_error_message = model_field.error_messages.get('unique', None) + if getattr(model_field, "unique", False): + unique_error_message = model_field.error_messages.get("unique", None) if unique_error_message: unique_error_message = unique_error_message % { - 'model_name': model_field.model._meta.verbose_name, - 'field_label': model_field.verbose_name + "model_name": model_field.model._meta.verbose_name, + "field_label": model_field.verbose_name, } validator = UniqueValidator( - queryset=model_field.model._default_manager, - message=unique_error_message) + queryset=model_field.model._default_manager, message=unique_error_message + ) validator_kwarg.append(validator) if validator_kwarg: - kwargs['validators'] = validator_kwarg + kwargs["validators"] = validator_kwarg return kwargs @@ -234,65 +266,65 @@ def get_relation_kwargs(field_name, relation_info): """ Creates a default instance of a flat relational field. """ - model_field, related_model, to_many, to_field, has_through_model, reverse = relation_info + model_field, related_model, to_many, to_field, has_through_model, reverse = ( + relation_info + ) kwargs = { - 'queryset': related_model._default_manager, - 'view_name': get_detail_view_name(related_model) + "queryset": related_model._default_manager, + "view_name": get_detail_view_name(related_model), } if to_many: - kwargs['many'] = True + kwargs["many"] = True if to_field: - kwargs['to_field'] = to_field + kwargs["to_field"] = to_field limit_choices_to = model_field and model_field.get_limit_choices_to() if limit_choices_to: if not isinstance(limit_choices_to, models.Q): limit_choices_to = models.Q(**limit_choices_to) - kwargs['queryset'] = kwargs['queryset'].filter(limit_choices_to) + kwargs["queryset"] = kwargs["queryset"].filter(limit_choices_to) if has_through_model: - kwargs['read_only'] = True - kwargs.pop('queryset', None) + kwargs["read_only"] = True + kwargs.pop("queryset", None) if model_field: if model_field.verbose_name and needs_label(model_field, field_name): - kwargs['label'] = capfirst(model_field.verbose_name) + kwargs["label"] = capfirst(model_field.verbose_name) help_text = model_field.help_text if help_text: - kwargs['help_text'] = help_text + kwargs["help_text"] = help_text if not model_field.editable: - kwargs['read_only'] = True - kwargs.pop('queryset', None) - if kwargs.get('read_only', False): + kwargs["read_only"] = True + kwargs.pop("queryset", None) + if kwargs.get("read_only", False): # If this field is read-only, then return early. # No further keyword arguments are valid. return kwargs if model_field.has_default() or model_field.blank or model_field.null: - kwargs['required'] = False + kwargs["required"] = False if model_field.null: - kwargs['allow_null'] = True + kwargs["allow_null"] = True if model_field.validators: - kwargs['validators'] = model_field.validators - if getattr(model_field, 'unique', False): + kwargs["validators"] = model_field.validators + if getattr(model_field, "unique", False): validator = UniqueValidator(queryset=model_field.model._default_manager) - kwargs['validators'] = kwargs.get('validators', []) + [validator] + kwargs["validators"] = kwargs.get("validators", []) + [validator] if to_many and not model_field.blank: - kwargs['allow_empty'] = False + kwargs["allow_empty"] = False return kwargs def get_nested_relation_kwargs(relation_info): - kwargs = {'read_only': True} + kwargs = {"read_only": True} if relation_info.to_many: - kwargs['many'] = True + kwargs["many"] = True return kwargs def get_url_kwargs(model_field): - return { - 'view_name': get_detail_view_name(model_field) - } + return {"view_name": get_detail_view_name(model_field)} diff --git a/rest_framework/utils/formatting.py b/rest_framework/utils/formatting.py index aa805f14e..0fcfa634a 100644 --- a/rest_framework/utils/formatting.py +++ b/rest_framework/utils/formatting.py @@ -18,7 +18,7 @@ def remove_trailing_string(content, trailing): Used when generating names from view classes. """ if content.endswith(trailing) and content != trailing: - return content[:-len(trailing)] + return content[: -len(trailing)] return content @@ -36,14 +36,14 @@ def dedent(content): # unindent the content if needed if lines: - whitespace_counts = min([len(line) - len(line.lstrip(' ')) for line in lines]) - tab_counts = min([len(line) - len(line.lstrip('\t')) for line in lines]) + whitespace_counts = min([len(line) - len(line.lstrip(" ")) for line in lines]) + tab_counts = min([len(line) - len(line.lstrip("\t")) for line in lines]) if whitespace_counts: - whitespace_pattern = '^' + (' ' * whitespace_counts) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) + whitespace_pattern = "^" + (" " * whitespace_counts) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content) elif tab_counts: - whitespace_pattern = '^' + ('\t' * tab_counts) - content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) + whitespace_pattern = "^" + ("\t" * tab_counts) + content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content) return content.strip() @@ -52,9 +52,9 @@ def camelcase_to_spaces(content): Translate 'CamelCaseNames' to 'Camel Case Names'. Used when generating names from view classes. """ - camelcase_boundary = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))' - content = re.sub(camelcase_boundary, ' \\1', content).strip() - return ' '.join(content.split('_')).title() + camelcase_boundary = "(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))" + content = re.sub(camelcase_boundary, " \\1", content).strip() + return " ".join(content.split("_")).title() def markup_description(description): @@ -64,6 +64,6 @@ def markup_description(description): if apply_markdown: description = apply_markdown(description) else: - description = escape(description).replace('\n', '
    ') - description = '

    ' + description + '

    ' + description = escape(description).replace("\n", "
    ") + description = "

    " + description + "

    " return mark_safe(description) diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py index c7ede7803..48820327f 100644 --- a/rest_framework/utils/html.py +++ b/rest_framework/utils/html.py @@ -9,10 +9,10 @@ from django.utils.datastructures import MultiValueDict def is_html_input(dictionary): # MultiDict type datastructures are used to represent HTML form input, # which may have more than one value for each key. - return hasattr(dictionary, 'getlist') + return hasattr(dictionary, "getlist") -def parse_html_list(dictionary, prefix='', default=None): +def parse_html_list(dictionary, prefix="", default=None): """ Used to support list values in HTML forms. Supports lists of primitives and/or dictionaries. @@ -48,7 +48,7 @@ def parse_html_list(dictionary, prefix='', default=None): :returns a list of objects, or the value specified in ``default`` if the list is empty """ ret = {} - regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix)) + regex = re.compile(r"^%s\[([0-9]+)\](.*)$" % re.escape(prefix)) for field, value in dictionary.items(): match = regex.match(field) if not match: @@ -66,7 +66,7 @@ def parse_html_list(dictionary, prefix='', default=None): return [ret[item] for item in sorted(ret)] if ret else default -def parse_html_dict(dictionary, prefix=''): +def parse_html_dict(dictionary, prefix=""): """ Used to support dictionary values in HTML forms. @@ -83,7 +83,7 @@ def parse_html_dict(dictionary, prefix=''): } """ ret = MultiValueDict() - regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix)) + regex = re.compile(r"^%s\.(.+)$" % re.escape(prefix)) for field in dictionary: match = regex.match(field) if not match: diff --git a/rest_framework/utils/humanize_datetime.py b/rest_framework/utils/humanize_datetime.py index 48ef89547..14ca1f663 100644 --- a/rest_framework/utils/humanize_datetime.py +++ b/rest_framework/utils/humanize_datetime.py @@ -5,20 +5,19 @@ from rest_framework import ISO_8601 def datetime_formats(formats): - format = ', '.join(formats).replace( - ISO_8601, - 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]' + format = ", ".join(formats).replace( + ISO_8601, "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]" ) return humanize_strptime(format) def date_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DD') + format = ", ".join(formats).replace(ISO_8601, "YYYY-MM-DD") return humanize_strptime(format) def time_formats(formats): - format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]') + format = ", ".join(formats).replace(ISO_8601, "hh:mm[:ss[.uuuuuu]]") return humanize_strptime(format) @@ -40,7 +39,7 @@ def humanize_strptime(format_string): "%a": "[Mon-Sun]", "%A": "[Monday-Sunday]", "%p": "[AM|PM]", - "%z": "[+HHMM|-HHMM]" + "%z": "[+HHMM|-HHMM]", } for key, val in mapping.items(): format_string = format_string.replace(key, val) diff --git a/rest_framework/utils/json.py b/rest_framework/utils/json.py index cb5572380..09ba12fe2 100644 --- a/rest_framework/utils/json.py +++ b/rest_framework/utils/json.py @@ -13,28 +13,28 @@ import json # noqa def strict_constant(o): - raise ValueError('Out of range float values are not JSON compliant: ' + repr(o)) + raise ValueError("Out of range float values are not JSON compliant: " + repr(o)) @functools.wraps(json.dump) def dump(*args, **kwargs): - kwargs.setdefault('allow_nan', False) + kwargs.setdefault("allow_nan", False) return json.dump(*args, **kwargs) @functools.wraps(json.dumps) def dumps(*args, **kwargs): - kwargs.setdefault('allow_nan', False) + kwargs.setdefault("allow_nan", False) return json.dumps(*args, **kwargs) @functools.wraps(json.load) def load(*args, **kwargs): - kwargs.setdefault('parse_constant', strict_constant) + kwargs.setdefault("parse_constant", strict_constant) return json.load(*args, **kwargs) @functools.wraps(json.loads) def loads(*args, **kwargs): - kwargs.setdefault('parse_constant', strict_constant) + kwargs.setdefault("parse_constant", strict_constant) return json.loads(*args, **kwargs) diff --git a/rest_framework/utils/mediatypes.py b/rest_framework/utils/mediatypes.py index f4acf4807..2a5752d8f 100644 --- a/rest_framework/utils/mediatypes.py +++ b/rest_framework/utils/mediatypes.py @@ -49,20 +49,30 @@ def order_by_precedence(media_type_lst): @python_2_unicode_compatible class _MediaType(object): def __init__(self, media_type_str): - self.orig = '' if (media_type_str is None) else media_type_str - self.full_type, self.params = parse_header(self.orig.encode(HTTP_HEADER_ENCODING)) - self.main_type, sep, self.sub_type = self.full_type.partition('/') + self.orig = "" if (media_type_str is None) else media_type_str + self.full_type, self.params = parse_header( + self.orig.encode(HTTP_HEADER_ENCODING) + ) + self.main_type, sep, self.sub_type = self.full_type.partition("/") def match(self, other): """Return true if this MediaType satisfies the given MediaType.""" for key in self.params: - if key != 'q' and other.params.get(key, None) != self.params.get(key, None): + if key != "q" and other.params.get(key, None) != self.params.get(key, None): return False - if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type: + if ( + self.sub_type != "*" + and other.sub_type != "*" + and other.sub_type != self.sub_type + ): return False - if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type: + if ( + self.main_type != "*" + and other.main_type != "*" + and other.main_type != self.main_type + ): return False return True @@ -72,16 +82,16 @@ class _MediaType(object): """ Return a precedence level from 0-3 for the media type given how specific it is. """ - if self.main_type == '*': + if self.main_type == "*": return 0 - elif self.sub_type == '*': + elif self.sub_type == "*": return 1 - elif not self.params or list(self.params) == ['q']: + elif not self.params or list(self.params) == ["q"]: return 2 return 3 def __str__(self): ret = "%s/%s" % (self.main_type, self.sub_type) for key, val in self.params.items(): - ret += "; %s=%s" % (key, val.decode('ascii')) + ret += "; %s=%s" % (key, val.decode("ascii")) return ret diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 4cc93b8ef..1bc8f1551 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -7,23 +7,30 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance. """ from collections import OrderedDict, namedtuple -FieldInfo = namedtuple('FieldResult', [ - 'pk', # Model field instance - 'fields', # Dict of field name -> model field instance - 'forward_relations', # Dict of field name -> RelationInfo - 'reverse_relations', # Dict of field name -> RelationInfo - 'fields_and_pk', # Shortcut for 'pk' + 'fields' - 'relations' # Shortcut for 'forward_relations' + 'reverse_relations' -]) -RelationInfo = namedtuple('RelationInfo', [ - 'model_field', - 'related_model', - 'to_many', - 'to_field', - 'has_through_model', - 'reverse' -]) +FieldInfo = namedtuple( + "FieldResult", + [ + "pk", # Model field instance + "fields", # Dict of field name -> model field instance + "forward_relations", # Dict of field name -> RelationInfo + "reverse_relations", # Dict of field name -> RelationInfo + "fields_and_pk", # Shortcut for 'pk' + 'fields' + "relations", # Shortcut for 'forward_relations' + 'reverse_relations' + ], +) + +RelationInfo = namedtuple( + "RelationInfo", + [ + "model_field", + "related_model", + "to_many", + "to_field", + "has_through_model", + "reverse", + ], +) def get_field_info(model): @@ -41,8 +48,9 @@ def get_field_info(model): fields_and_pk = _merge_fields_and_pk(pk, fields) relationships = _merge_relationships(forward_relations, reverse_relations) - return FieldInfo(pk, fields, forward_relations, reverse_relations, - fields_and_pk, relationships) + return FieldInfo( + pk, fields, forward_relations, reverse_relations, fields_and_pk, relationships + ) def _get_pk(opts): @@ -59,14 +67,16 @@ def _get_pk(opts): def _get_fields(opts): fields = OrderedDict() - for field in [field for field in opts.fields if field.serialize and not field.remote_field]: + for field in [ + field for field in opts.fields if field.serialize and not field.remote_field + ]: fields[field.name] = field return fields def _get_to_field(field): - return getattr(field, 'to_fields', None) and field.to_fields[0] + return getattr(field, "to_fields", None) and field.to_fields[0] def _get_forward_relationships(opts): @@ -74,14 +84,16 @@ def _get_forward_relationships(opts): Returns an `OrderedDict` of field names to `RelationInfo`. """ forward_relations = OrderedDict() - for field in [field for field in opts.fields if field.serialize and field.remote_field]: + for field in [ + field for field in opts.fields if field.serialize and field.remote_field + ]: forward_relations[field.name] = RelationInfo( model_field=field, related_model=field.remote_field.model, to_many=False, to_field=_get_to_field(field), has_through_model=False, - reverse=False + reverse=False, ) # Deal with forward many-to-many relationships. @@ -92,10 +104,8 @@ def _get_forward_relationships(opts): to_many=True, # manytomany do not have to_fields to_field=None, - has_through_model=( - not field.remote_field.through._meta.auto_created - ), - reverse=False + has_through_model=(not field.remote_field.through._meta.auto_created), + reverse=False, ) return forward_relations @@ -115,11 +125,13 @@ def _get_reverse_relationships(opts): to_many=relation.field.remote_field.multiple, to_field=_get_to_field(relation.field), has_through_model=False, - reverse=True + reverse=True, ) # Deal with reverse many-to-many relationships. - all_related_many_to_many_objects = [r for r in opts.related_objects if r.field.many_to_many] + all_related_many_to_many_objects = [ + r for r in opts.related_objects if r.field.many_to_many + ] for relation in all_related_many_to_many_objects: accessor_name = relation.get_accessor_name() reverse_relations[accessor_name] = RelationInfo( @@ -129,10 +141,10 @@ def _get_reverse_relationships(opts): # manytomany do not have to_fields to_field=None, has_through_model=( - (getattr(relation.field.remote_field, 'through', None) is not None) and - not relation.field.remote_field.through._meta.auto_created + (getattr(relation.field.remote_field, "through", None) is not None) + and not relation.field.remote_field.through._meta.auto_created ), - reverse=True + reverse=True, ) return reverse_relations @@ -140,7 +152,7 @@ def _get_reverse_relationships(opts): def _merge_fields_and_pk(pk, fields): fields_and_pk = OrderedDict() - fields_and_pk['pk'] = pk + fields_and_pk["pk"] = pk fields_and_pk[pk.name] = pk fields_and_pk.update(fields) @@ -149,8 +161,7 @@ def _merge_fields_and_pk(pk, fields): def _merge_relationships(forward_relations, reverse_relations): return OrderedDict( - list(forward_relations.items()) + - list(reverse_relations.items()) + list(forward_relations.items()) + list(reverse_relations.items()) ) @@ -158,4 +169,8 @@ def is_abstract_model(model): """ Given a model class, returns a boolean True if it is abstract and False if it is not. """ - return hasattr(model, '_meta') and hasattr(model._meta, 'abstract') and model._meta.abstract + return ( + hasattr(model, "_meta") + and hasattr(model._meta, "abstract") + and model._meta.abstract + ) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index deeaf1f63..43a1bb3c2 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -16,14 +16,10 @@ from rest_framework.compat import unicode_repr def manager_repr(value): model = value.model opts = model._meta - names_and_managers = [ - (manager.name, manager) - for manager - in opts.managers - ] + names_and_managers = [(manager.name, manager) for manager in opts.managers] for manager_name, manager_instance in names_and_managers: if manager_instance == value: - return '%s.%s.all()' % (model._meta.object_name, manager_name) + return "%s.%s.all()" % (model._meta.object_name, manager_name) return repr(value) @@ -45,7 +41,7 @@ def smart_repr(value): # # Should be presented as # - value = re.sub(' at 0x[0-9A-Fa-f]{4,32}>', '>', value) + value = re.sub(" at 0x[0-9A-Fa-f]{4,32}>", ">", value) return value @@ -54,16 +50,15 @@ def field_repr(field, force_many=False): kwargs = field._kwargs if force_many: kwargs = kwargs.copy() - kwargs['many'] = True - kwargs.pop('child', None) + kwargs["many"] = True + kwargs.pop("child", None) - arg_string = ', '.join([smart_repr(val) for val in field._args]) - kwarg_string = ', '.join([ - '%s=%s' % (key, smart_repr(val)) - for key, val in sorted(kwargs.items()) - ]) + arg_string = ", ".join([smart_repr(val) for val in field._args]) + kwarg_string = ", ".join( + ["%s=%s" % (key, smart_repr(val)) for key, val in sorted(kwargs.items())] + ) if arg_string and kwarg_string: - arg_string += ', ' + arg_string += ", " if force_many: class_name = force_many.__class__.__name__ @@ -74,8 +69,8 @@ def field_repr(field, force_many=False): def serializer_repr(serializer, indent, force_many=None): - ret = field_repr(serializer, force_many) + ':' - indent_str = ' ' * indent + ret = field_repr(serializer, force_many) + ":" + indent_str = " " * indent if force_many: fields = force_many.fields @@ -83,25 +78,27 @@ def serializer_repr(serializer, indent, force_many=None): fields = serializer.fields for field_name, field in fields.items(): - ret += '\n' + indent_str + field_name + ' = ' - if hasattr(field, 'fields'): + ret += "\n" + indent_str + field_name + " = " + if hasattr(field, "fields"): ret += serializer_repr(field, indent + 1) - elif hasattr(field, 'child'): + elif hasattr(field, "child"): ret += list_repr(field, indent + 1) - elif hasattr(field, 'child_relation'): + elif hasattr(field, "child_relation"): ret += field_repr(field.child_relation, force_many=field.child_relation) else: ret += field_repr(field) if serializer.validators: - ret += '\n' + indent_str + 'class Meta:' - ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators) + ret += "\n" + indent_str + "class Meta:" + ret += ( + "\n" + indent_str + " validators = " + smart_repr(serializer.validators) + ) return ret def list_repr(serializer, indent): child = serializer.child - if hasattr(child, 'fields'): + if hasattr(child, "fields"): return serializer_repr(serializer, indent, force_many=child) return field_repr(serializer) diff --git a/rest_framework/utils/serializer_helpers.py b/rest_framework/utils/serializer_helpers.py index c24e51d09..e2db8c130 100644 --- a/rest_framework/utils/serializer_helpers.py +++ b/rest_framework/utils/serializer_helpers.py @@ -16,7 +16,7 @@ class ReturnDict(OrderedDict): """ def __init__(self, *args, **kwargs): - self.serializer = kwargs.pop('serializer') + self.serializer = kwargs.pop("serializer") super(ReturnDict, self).__init__(*args, **kwargs) def copy(self): @@ -39,7 +39,7 @@ class ReturnList(list): """ def __init__(self, *args, **kwargs): - self.serializer = kwargs.pop('serializer') + self.serializer = kwargs.pop("serializer") super(ReturnList, self).__init__(*args, **kwargs) def __repr__(self): @@ -58,7 +58,7 @@ class BoundField(object): providing an API similar to Django forms and form fields. """ - def __init__(self, field, value, errors, prefix=''): + def __init__(self, field, value, errors, prefix=""): self._field = field self._prefix = prefix self.value = value @@ -73,12 +73,13 @@ class BoundField(object): return self._field.__class__ def __repr__(self): - return unicode_to_repr('<%s value=%s errors=%s>' % ( - self.__class__.__name__, self.value, self.errors - )) + return unicode_to_repr( + "<%s value=%s errors=%s>" + % (self.__class__.__name__, self.value, self.errors) + ) def as_form_field(self): - value = '' if (self.value is None or self.value is False) else self.value + value = "" if (self.value is None or self.value is False) else self.value return self.__class__(self._field, value, self.errors, self._prefix) @@ -87,7 +88,7 @@ class JSONBoundField(BoundField): value = self.value # When HTML form input is used and the input is not valid # value will be a JSONString, rather than a JSON primitive. - if not getattr(value, 'is_json_string', False): + if not getattr(value, "is_json_string", False): try: value = json.dumps(self.value, sort_keys=True, indent=4) except (TypeError, ValueError): @@ -102,8 +103,8 @@ class NestedBoundField(BoundField): `BoundField` that is used for serializer fields. """ - def __init__(self, field, value, errors, prefix=''): - if value is None or value is '': + def __init__(self, field, value, errors, prefix=""): + if value is None or value is "": value = {} super(NestedBoundField, self).__init__(field, value, errors, prefix) @@ -115,9 +116,9 @@ class NestedBoundField(BoundField): field = self.fields[key] value = self.value.get(key) if self.value else None error = self.errors.get(key) if isinstance(self.errors, dict) else None - if hasattr(field, 'fields'): - return NestedBoundField(field, value, error, prefix=self.name + '.') - return BoundField(field, value, error, prefix=self.name + '.') + if hasattr(field, "fields"): + return NestedBoundField(field, value, error, prefix=self.name + ".") + return BoundField(field, value, error, prefix=self.name + ".") def as_form_field(self): values = {} @@ -125,7 +126,9 @@ class NestedBoundField(BoundField): if isinstance(value, (list, dict)): values[key] = value else: - values[key] = '' if (value is None or value is False) else force_text(value) + values[key] = ( + "" if (value is None or value is False) else force_text(value) + ) return self.__class__(self._field, values, self.errors, self._prefix) diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 2ea3e5ac1..b2b06adf7 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -39,9 +39,10 @@ class UniqueValidator(object): Should be applied to an individual field on the serializer. """ - message = _('This field must be unique.') - def __init__(self, queryset, message=None, lookup='exact'): + message = _("This field must be unique.") + + def __init__(self, queryset, message=None, lookup="exact"): self.queryset = queryset self.serializer_field = None self.message = message or self.message @@ -56,13 +57,13 @@ class UniqueValidator(object): # same as the serializer field name if `source=<>` is set. self.field_name = serializer_field.source_attrs[-1] # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer_field.parent, 'instance', None) + self.instance = getattr(serializer_field.parent, "instance", None) def filter_queryset(self, value, queryset): """ Filter the queryset to all instances matching the given attribute. """ - filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value} + filter_kwargs = {"%s__%s" % (self.field_name, self.lookup): value} return qs_filter(queryset, **filter_kwargs) def exclude_current_instance(self, queryset): @@ -79,13 +80,12 @@ class UniqueValidator(object): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if qs_exists(queryset): - raise ValidationError(self.message, code='unique') + raise ValidationError(self.message, code="unique") def __repr__(self): - return unicode_to_repr('<%s(queryset=%s)>' % ( - self.__class__.__name__, - smart_repr(self.queryset) - )) + return unicode_to_repr( + "<%s(queryset=%s)>" % (self.__class__.__name__, smart_repr(self.queryset)) + ) class UniqueTogetherValidator(object): @@ -94,8 +94,9 @@ class UniqueTogetherValidator(object): Should be applied to the serializer class, not to an individual field. """ - message = _('The fields {field_names} must make a unique set.') - missing_message = _('This field is required.') + + message = _("The fields {field_names} must make a unique set.") + missing_message = _("This field is required.") def __init__(self, queryset, fields, message=None): self.queryset = queryset @@ -109,7 +110,7 @@ class UniqueTogetherValidator(object): prior to the validation call being made. """ # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer, 'instance', None) + self.instance = getattr(serializer, "instance", None) def enforce_required_fields(self, attrs): """ @@ -125,7 +126,7 @@ class UniqueTogetherValidator(object): if field_name not in attrs } if missing_items: - raise ValidationError(missing_items, code='required') + raise ValidationError(missing_items, code="required") def filter_queryset(self, attrs, queryset): """ @@ -139,10 +140,7 @@ class UniqueTogetherValidator(object): attrs[field_name] = getattr(self.instance, field_name) # Determine the filter keyword arguments and filter the queryset. - filter_kwargs = { - field_name: attrs[field_name] - for field_name in self.fields - } + filter_kwargs = {field_name: attrs[field_name] for field_name in self.fields} return qs_filter(queryset, **filter_kwargs) def exclude_current_instance(self, attrs, queryset): @@ -165,21 +163,24 @@ class UniqueTogetherValidator(object): value for field, value in attrs.items() if field in self.fields ] if None not in checked_values and qs_exists(queryset): - field_names = ', '.join(self.fields) + field_names = ", ".join(self.fields) message = self.message.format(field_names=field_names) - raise ValidationError(message, code='unique') + raise ValidationError(message, code="unique") def __repr__(self): - return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( - self.__class__.__name__, - smart_repr(self.queryset), - smart_repr(self.fields) - )) + return unicode_to_repr( + "<%s(queryset=%s, fields=%s)>" + % ( + self.__class__.__name__, + smart_repr(self.queryset), + smart_repr(self.fields), + ) + ) class BaseUniqueForValidator(object): message = None - missing_message = _('This field is required.') + missing_message = _("This field is required.") def __init__(self, queryset, field, date_field, message=None): self.queryset = queryset @@ -197,7 +198,7 @@ class BaseUniqueForValidator(object): self.field_name = serializer.fields[self.field].source_attrs[-1] self.date_field_name = serializer.fields[self.date_field].source_attrs[-1] # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer, 'instance', None) + self.instance = getattr(serializer, "instance", None) def enforce_required_fields(self, attrs): """ @@ -210,10 +211,10 @@ class BaseUniqueForValidator(object): if field_name not in attrs } if missing_items: - raise ValidationError(missing_items, code='required') + raise ValidationError(missing_items, code="required") def filter_queryset(self, attrs, queryset): - raise NotImplementedError('`filter_queryset` must be implemented.') + raise NotImplementedError("`filter_queryset` must be implemented.") def exclude_current_instance(self, attrs, queryset): """ @@ -231,17 +232,18 @@ class BaseUniqueForValidator(object): queryset = self.exclude_current_instance(attrs, queryset) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) - raise ValidationError({ - self.field: message - }, code='unique') + raise ValidationError({self.field: message}, code="unique") def __repr__(self): - return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( - self.__class__.__name__, - smart_repr(self.queryset), - smart_repr(self.field), - smart_repr(self.date_field) - )) + return unicode_to_repr( + "<%s(queryset=%s, field=%s, date_field=%s)>" + % ( + self.__class__.__name__, + smart_repr(self.queryset), + smart_repr(self.field), + smart_repr(self.date_field), + ) + ) class UniqueForDateValidator(BaseUniqueForValidator): @@ -253,9 +255,9 @@ class UniqueForDateValidator(BaseUniqueForValidator): filter_kwargs = {} filter_kwargs[self.field_name] = value - filter_kwargs['%s__day' % self.date_field_name] = date.day - filter_kwargs['%s__month' % self.date_field_name] = date.month - filter_kwargs['%s__year' % self.date_field_name] = date.year + filter_kwargs["%s__day" % self.date_field_name] = date.day + filter_kwargs["%s__month" % self.date_field_name] = date.month + filter_kwargs["%s__year" % self.date_field_name] = date.year return qs_filter(queryset, **filter_kwargs) @@ -268,7 +270,7 @@ class UniqueForMonthValidator(BaseUniqueForValidator): filter_kwargs = {} filter_kwargs[self.field_name] = value - filter_kwargs['%s__month' % self.date_field_name] = date.month + filter_kwargs["%s__month" % self.date_field_name] = date.month return qs_filter(queryset, **filter_kwargs) @@ -281,5 +283,5 @@ class UniqueForYearValidator(BaseUniqueForValidator): filter_kwargs = {} filter_kwargs[self.field_name] = value - filter_kwargs['%s__year' % self.date_field_name] = date.year + filter_kwargs["%s__year" % self.date_field_name] = date.year return qs_filter(queryset, **filter_kwargs) diff --git a/rest_framework/versioning.py b/rest_framework/versioning.py index 206ff6c2e..4e3bf855f 100644 --- a/rest_framework/versioning.py +++ b/rest_framework/versioning.py @@ -19,19 +19,20 @@ class BaseVersioning(object): version_param = api_settings.VERSION_PARAM def determine_version(self, request, *args, **kwargs): - msg = '{cls}.determine_version() must be implemented.' - raise NotImplementedError(msg.format( - cls=self.__class__.__name__ - )) + msg = "{cls}.determine_version() must be implemented." + raise NotImplementedError(msg.format(cls=self.__class__.__name__)) - def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + def reverse( + self, viewname, args=None, kwargs=None, request=None, format=None, **extra + ): return _reverse(viewname, args, kwargs, request, format, **extra) def is_allowed_version(self, version): if not self.allowed_versions: return True - return ((version is not None and version == self.default_version) or - (version in self.allowed_versions)) + return (version is not None and version == self.default_version) or ( + version in self.allowed_versions + ) class AcceptHeaderVersioning(BaseVersioning): @@ -40,6 +41,7 @@ class AcceptHeaderVersioning(BaseVersioning): Host: example.com Accept: application/json; version=1.0 """ + invalid_version_message = _('Invalid version in "Accept" header.') def determine_version(self, request, *args, **kwargs): @@ -71,7 +73,8 @@ class URLPathVersioning(BaseVersioning): Host: example.com Accept: application/json """ - invalid_version_message = _('Invalid version in URL path.') + + invalid_version_message = _("Invalid version in URL path.") def determine_version(self, request, *args, **kwargs): version = kwargs.get(self.version_param, self.default_version) @@ -82,7 +85,9 @@ class URLPathVersioning(BaseVersioning): raise exceptions.NotFound(self.invalid_version_message) return version - def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + def reverse( + self, viewname, args=None, kwargs=None, request=None, format=None, **extra + ): if request.version is not None: kwargs = {} if (kwargs is None) else kwargs kwargs[self.version_param] = request.version @@ -116,21 +121,26 @@ class NamespaceVersioning(BaseVersioning): Host: example.com Accept: application/json """ - invalid_version_message = _('Invalid version in URL path. Does not match any version namespace.') + + invalid_version_message = _( + "Invalid version in URL path. Does not match any version namespace." + ) def determine_version(self, request, *args, **kwargs): - resolver_match = getattr(request, 'resolver_match', None) + resolver_match = getattr(request, "resolver_match", None) if resolver_match is None or not resolver_match.namespace: return self.default_version # Allow for possibly nested namespaces. - possible_versions = resolver_match.namespace.split(':') + possible_versions = resolver_match.namespace.split(":") for version in possible_versions: if self.is_allowed_version(version): return version raise exceptions.NotFound(self.invalid_version_message) - def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + def reverse( + self, viewname, args=None, kwargs=None, request=None, format=None, **extra + ): if request.version is not None: viewname = self.get_versioned_viewname(viewname, request) return super(NamespaceVersioning, self).reverse( @@ -138,7 +148,7 @@ class NamespaceVersioning(BaseVersioning): ) def get_versioned_viewname(self, viewname, request): - return request.version + ':' + viewname + return request.version + ":" + viewname class HostNameVersioning(BaseVersioning): @@ -147,11 +157,12 @@ class HostNameVersioning(BaseVersioning): Host: v1.example.com Accept: application/json """ - hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$') - invalid_version_message = _('Invalid version in hostname.') + + hostname_regex = re.compile(r"^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$") + invalid_version_message = _("Invalid version in hostname.") def determine_version(self, request, *args, **kwargs): - hostname, separator, port = request.get_host().partition(':') + hostname, separator, port = request.get_host().partition(":") match = self.hostname_regex.match(hostname) if not match: return self.default_version @@ -170,7 +181,8 @@ class QueryParameterVersioning(BaseVersioning): Host: example.com Accept: application/json """ - invalid_version_message = _('Invalid version in query parameter.') + + invalid_version_message = _("Invalid version in query parameter.") def determine_version(self, request, *args, **kwargs): version = request.query_params.get(self.version_param, self.default_version) @@ -178,7 +190,9 @@ class QueryParameterVersioning(BaseVersioning): raise exceptions.NotFound(self.invalid_version_message) return version - def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): + def reverse( + self, viewname, args=None, kwargs=None, request=None, format=None, **extra + ): url = super(QueryParameterVersioning, self).reverse( viewname, args, kwargs, request, format, **extra ) diff --git a/rest_framework/views.py b/rest_framework/views.py index 9d5d959e9..d4c006dfe 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -29,19 +29,19 @@ def get_view_name(view): This function is the default for the `VIEW_NAME_FUNCTION` setting. """ # Name may be set by some Views, such as a ViewSet. - name = getattr(view, 'name', None) + name = getattr(view, "name", None) if name is not None: return name name = view.__class__.__name__ - name = formatting.remove_trailing_string(name, 'View') - name = formatting.remove_trailing_string(name, 'ViewSet') + name = formatting.remove_trailing_string(name, "View") + name = formatting.remove_trailing_string(name, "ViewSet") name = formatting.camelcase_to_spaces(name) # Suffix may be set by some Views, such as a ViewSet. - suffix = getattr(view, 'suffix', None) + suffix = getattr(view, "suffix", None) if suffix: - name += ' ' + suffix + name += " " + suffix return name @@ -54,9 +54,9 @@ def get_view_description(view, html=False): This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting. """ # Description may be set by some Views, such as a ViewSet. - description = getattr(view, 'description', None) + description = getattr(view, "description", None) if description is None: - description = view.__class__.__doc__ or '' + description = view.__class__.__doc__ or "" description = formatting.dedent(smart_text(description)) if html: @@ -65,7 +65,7 @@ def get_view_description(view, html=False): def set_rollback(): - atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) + atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False) if atomic_requests and connection.in_atomic_block: transaction.set_rollback(True) @@ -87,15 +87,15 @@ def exception_handler(exc, context): if isinstance(exc, exceptions.APIException): headers = {} - if getattr(exc, 'auth_header', None): - headers['WWW-Authenticate'] = exc.auth_header - if getattr(exc, 'wait', None): - headers['Retry-After'] = '%d' % exc.wait + if getattr(exc, "auth_header", None): + headers["WWW-Authenticate"] = exc.auth_header + if getattr(exc, "wait", None): + headers["Retry-After"] = "%d" % exc.wait if isinstance(exc.detail, (list, dict)): data = exc.detail else: - data = {'detail': exc.detail} + data = {"detail": exc.detail} set_rollback() return Response(data, status=exc.status_code, headers=headers) @@ -128,13 +128,15 @@ class APIView(View): This allows us to discover information about the view when we do URL reverse lookups. Used for breadcrumb generation. """ - if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet): + if isinstance(getattr(cls, "queryset", None), models.query.QuerySet): + def force_evaluation(): raise RuntimeError( - 'Do not evaluate the `.queryset` attribute directly, ' - 'as the result will be cached and reused between requests. ' - 'Use `.all()` or call `.get_queryset()` instead.' + "Do not evaluate the `.queryset` attribute directly, " + "as the result will be cached and reused between requests. " + "Use `.all()` or call `.get_queryset()` instead." ) + cls.queryset._fetch_all = force_evaluation view = super(APIView, cls).as_view(**initkwargs) @@ -154,11 +156,9 @@ class APIView(View): @property def default_response_headers(self): - headers = { - 'Allow': ', '.join(self.allowed_methods), - } + headers = {"Allow": ", ".join(self.allowed_methods)} if len(self.renderer_classes) > 1: - headers['Vary'] = 'Accept' + headers["Vary"] = "Accept" return headers def http_method_not_allowed(self, request, *args, **kwargs): @@ -199,9 +199,9 @@ class APIView(View): # Note: Additionally `request` and `encoding` will also be added # to the context by the Request object. return { - 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}) + "view": self, + "args": getattr(self, "args", ()), + "kwargs": getattr(self, "kwargs", {}), } def get_renderer_context(self): @@ -212,10 +212,10 @@ class APIView(View): # Note: Additionally 'response' will also be added to the context, # by the Response object. return { - 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}), - 'request': getattr(self, 'request', None) + "view": self, + "args": getattr(self, "args", ()), + "kwargs": getattr(self, "kwargs", {}), + "request": getattr(self, "request", None), } def get_exception_handler_context(self): @@ -224,10 +224,10 @@ class APIView(View): as the `context` argument. """ return { - 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}), - 'request': getattr(self, 'request', None) + "view": self, + "args": getattr(self, "args", ()), + "kwargs": getattr(self, "kwargs", {}), + "request": getattr(self, "request", None), } def get_view_name(self): @@ -289,7 +289,7 @@ class APIView(View): """ Instantiate and return the content negotiation class to use. """ - if not getattr(self, '_negotiator', None): + if not getattr(self, "_negotiator", None): self._negotiator = self.content_negotiation_class() return self._negotiator @@ -333,7 +333,7 @@ class APIView(View): for permission in self.get_permissions(): if not permission.has_permission(request, self): self.permission_denied( - request, message=getattr(permission, 'message', None) + request, message=getattr(permission, "message", None) ) def check_object_permissions(self, request, obj): @@ -344,7 +344,7 @@ class APIView(View): for permission in self.get_permissions(): if not permission.has_object_permission(request, self, obj): self.permission_denied( - request, message=getattr(permission, 'message', None) + request, message=getattr(permission, "message", None) ) def check_throttles(self, request): @@ -379,7 +379,7 @@ class APIView(View): parsers=self.get_parsers(), authenticators=self.get_authenticators(), negotiator=self.get_content_negotiator(), - parser_context=parser_context + parser_context=parser_context, ) def initial(self, request, *args, **kwargs): @@ -407,13 +407,12 @@ class APIView(View): """ # Make the error obvious if a proper response is not returned assert isinstance(response, HttpResponseBase), ( - 'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` ' - 'to be returned from the view, but received a `%s`' - % type(response) + "Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` " + "to be returned from the view, but received a `%s`" % type(response) ) if isinstance(response, Response): - if not getattr(request, 'accepted_renderer', None): + if not getattr(request, "accepted_renderer", None): neg = self.perform_content_negotiation(request, force=True) request.accepted_renderer, request.accepted_media_type = neg @@ -422,7 +421,7 @@ class APIView(View): response.renderer_context = self.get_renderer_context() # Add new vary headers to the response instead of overwriting. - vary_headers = self.headers.pop('Vary', None) + vary_headers = self.headers.pop("Vary", None) if vary_headers is not None: patch_vary_headers(response, cc_delim_re.split(vary_headers)) @@ -436,8 +435,9 @@ class APIView(View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ - if isinstance(exc, (exceptions.NotAuthenticated, - exceptions.AuthenticationFailed)): + if isinstance( + exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed) + ): # WWW-Authenticate header for 401 responses, else coerce to 403 auth_header = self.get_authenticate_header(self.request) @@ -460,8 +460,8 @@ class APIView(View): def raise_uncaught_exception(self, exc): if settings.DEBUG: request = self.request - renderer_format = getattr(request.accepted_renderer, 'format') - use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin') + renderer_format = getattr(request.accepted_renderer, "format") + use_plaintext_traceback = renderer_format not in ("html", "api", "admin") request.force_plaintext_errors(use_plaintext_traceback) raise exc @@ -484,8 +484,9 @@ class APIView(View): # Get the appropriate handler method if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), - self.http_method_not_allowed) + handler = getattr( + self, request.method.lower(), self.http_method_not_allowed + ) else: handler = self.http_method_not_allowed diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index 7146828d2..fd57789a8 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -31,7 +31,7 @@ from rest_framework.reverse import reverse def _is_extra_action(attr): - return hasattr(attr, 'mapping') + return hasattr(attr, "mapping") class ViewSetMixin(object): @@ -73,24 +73,30 @@ class ViewSetMixin(object): # actions must not be empty if not actions: - raise TypeError("The `actions` argument must be provided when " - "calling `.as_view()` on a ViewSet. For example " - "`.as_view({'get': 'list'})`") + raise TypeError( + "The `actions` argument must be provided when " + "calling `.as_view()` on a ViewSet. For example " + "`.as_view({'get': 'list'})`" + ) # sanitize keyword arguments for key in initkwargs: if key in cls.http_method_names: - raise TypeError("You tried to pass in the %s method name as a " - "keyword argument to %s(). Don't do that." - % (key, cls.__name__)) + raise TypeError( + "You tried to pass in the %s method name as a " + "keyword argument to %s(). Don't do that." % (key, cls.__name__) + ) if not hasattr(cls, key): - raise TypeError("%s() received an invalid keyword %r" % ( - cls.__name__, key)) + raise TypeError( + "%s() received an invalid keyword %r" % (cls.__name__, key) + ) # name and suffix are mutually exclusive - if 'name' in initkwargs and 'suffix' in initkwargs: - raise TypeError("%s() received both `name` and `suffix`, which are " - "mutually exclusive arguments." % (cls.__name__)) + if "name" in initkwargs and "suffix" in initkwargs: + raise TypeError( + "%s() received both `name` and `suffix`, which are " + "mutually exclusive arguments." % (cls.__name__) + ) def view(request, *args, **kwargs): self = cls(**initkwargs) @@ -105,7 +111,7 @@ class ViewSetMixin(object): handler = getattr(self, action) setattr(self, method, handler) - if hasattr(self, 'get') and not hasattr(self, 'head'): + if hasattr(self, "get") and not hasattr(self, "head"): self.head = self.get self.request = request @@ -136,11 +142,11 @@ class ViewSetMixin(object): """ request = super(ViewSetMixin, self).initialize_request(request, *args, **kwargs) method = request.method.lower() - if method == 'options': + if method == "options": # This is a special case as we always provide handling for the # options method in the base `View` class. # Unlike the other explicitly defined actions, 'metadata' is implicit. - self.action = 'metadata' + self.action = "metadata" else: self.action = self.action_map.get(method) return request @@ -149,8 +155,8 @@ class ViewSetMixin(object): """ Reverse the action for the given `url_name`. """ - url_name = '%s-%s' % (self.basename, url_name) - kwargs.setdefault('request', self.request) + url_name = "%s-%s" % (self.basename, url_name) + kwargs.setdefault("request", self.request) return reverse(url_name, *args, **kwargs) @@ -175,13 +181,14 @@ class ViewSetMixin(object): # filter for the relevant extra actions actions = [ - action for action in self.get_extra_actions() + action + for action in self.get_extra_actions() if action.detail == self.detail ] for action in actions: try: - url_name = '%s-%s' % (self.basename, action.url_name) + url_name = "%s-%s" % (self.basename, action.url_name) url = reverse(url_name, self.args, self.kwargs, request=self.request) view = self.__class__(**action.kwargs) action_urls[view.get_view_name()] = url @@ -195,6 +202,7 @@ class ViewSet(ViewSetMixin, views.APIView): """ The base ViewSet class does not provide any actions by default. """ + pass @@ -204,26 +212,31 @@ class GenericViewSet(ViewSetMixin, generics.GenericAPIView): but does include the base set of generic view behavior, such as the `get_object` and `get_queryset` methods. """ + pass -class ReadOnlyModelViewSet(mixins.RetrieveModelMixin, - mixins.ListModelMixin, - GenericViewSet): +class ReadOnlyModelViewSet( + mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet +): """ A viewset that provides default `list()` and `retrieve()` actions. """ + pass -class ModelViewSet(mixins.CreateModelMixin, - mixins.RetrieveModelMixin, - mixins.UpdateModelMixin, - mixins.DestroyModelMixin, - mixins.ListModelMixin, - GenericViewSet): +class ModelViewSet( + mixins.CreateModelMixin, + mixins.RetrieveModelMixin, + mixins.UpdateModelMixin, + mixins.DestroyModelMixin, + mixins.ListModelMixin, + GenericViewSet, +): """ A viewset that provides default `create()`, `retrieve()`, `update()`, `partial_update()`, `destroy()` and `list()` actions. """ + pass diff --git a/runtests.py b/runtests.py index 4dc475375..a5c30e11a 100755 --- a/runtests.py +++ b/runtests.py @@ -6,16 +6,23 @@ import sys import pytest -PYTEST_ARGS = { - 'default': [], - 'fast': ['-q'], -} -FLAKE8_ARGS = ['rest_framework', 'tests'] +PYTEST_ARGS = {"default": [], "fast": ["-q"]} -ISORT_ARGS = ['--recursive', '--check-only', '--diff', '-o' 'uritemplate', '-p', 'tests', 'rest_framework', 'tests'] +FLAKE8_ARGS = ["rest_framework", "tests"] -BLACK_ARGS = ['--check', '--verbose'] +ISORT_ARGS = [ + "--recursive", + "--check-only", + "--diff", + "-o" "uritemplate", + "-p", + "tests", + "rest_framework", + "tests", +] + +BLACK_ARGS = ["--check", "--verbose"] def exit_on_failure(ret, message=None): @@ -24,43 +31,48 @@ def exit_on_failure(ret, message=None): def flake8_main(args): - print('Running flake8 code linting') - ret = subprocess.call(['flake8'] + args) - print('flake8 failed' if ret else 'flake8 passed') + print("Running flake8 code linting") + ret = subprocess.call(["flake8"] + args) + print("flake8 failed" if ret else "flake8 passed") return ret def isort_main(args): - print('Running isort code checking') - ret = subprocess.call(['isort'] + args) + print("Running isort code checking") + ret = subprocess.call(["isort"] + args) if ret: - print('isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`') + print( + "isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`" + ) else: - print('isort passed') + print("isort passed") return ret def black_main(args): - print('Running black code checking') - ret = subprocess.call(['black', '.'] + args) + print("Running black code checking") + ret = subprocess.call(["black", "."] + args) if ret: - print('black failed: Some code have incorrectly formatted. Fix by running `black .`') + print( + "black failed: Some code have incorrectly formatted. Fix by running `black .`" + ) else: - print('black passed') + print("black passed") return ret + def split_class_and_function(string): - class_string, function_string = string.split('.', 1) + class_string, function_string = string.split(".", 1) return "%s and %s" % (class_string, function_string) def is_function(string): # `True` if it looks like a test function is included in the string. - return string.startswith('test_') or '.test_' in string + return string.startswith("test_") or ".test_" in string def is_class(string): @@ -70,7 +82,7 @@ def is_class(string): if __name__ == "__main__": try: - sys.argv.remove('--nolint') + sys.argv.remove("--nolint") except ValueError: run_black = True run_flake8 = True @@ -81,18 +93,18 @@ if __name__ == "__main__": run_isort = False try: - sys.argv.remove('--lintonly') + sys.argv.remove("--lintonly") except ValueError: run_tests = True else: run_tests = False try: - sys.argv.remove('--fast') + sys.argv.remove("--fast") except ValueError: - style = 'default' + style = "default" else: - style = 'fast' + style = "fast" run_black = False run_flake8 = False run_isort = False @@ -102,26 +114,23 @@ if __name__ == "__main__": first_arg = pytest_args[0] try: - pytest_args.remove('--coverage') + pytest_args.remove("--coverage") except ValueError: pass else: - pytest_args = [ - '--cov', '.', - '--cov-report', 'xml', - ] + pytest_args + pytest_args = ["--cov", ".", "--cov-report", "xml"] + pytest_args - if first_arg.startswith('-'): + if first_arg.startswith("-"): # `runtests.py [flags]` - pytest_args = ['tests'] + pytest_args + pytest_args = ["tests"] + pytest_args elif is_class(first_arg) and is_function(first_arg): # `runtests.py TestCase.test_function [flags]` expression = split_class_and_function(first_arg) - pytest_args = ['tests', '-k', expression] + pytest_args[1:] + pytest_args = ["tests", "-k", expression] + pytest_args[1:] elif is_class(first_arg) or is_function(first_arg): # `runtests.py TestCase [flags]` # `runtests.py test_function [flags]` - pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:] + pytest_args = ["tests", "-k", pytest_args[0]] + pytest_args[1:] else: pytest_args = PYTEST_ARGS[style] diff --git a/setup.cfg b/setup.cfg index c95134600..8056bfe93 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,16 +9,23 @@ addopts=--tb=short --strict -ra testspath = tests [flake8] -ignore = E501 +max-line-length = 120 +ignore = E501, W503, E203 banned-modules = json = use from rest_framework.utils import json! [isort] skip=.tox atomic=true -multi_line_output=5 -known_standard_library=types +multi_line_output=3 +lines_after_imports = 2 +black=types +combine_as_imports = true known_third_party=pytest,_pytest,django,pytz -known_first_party=rest_framework +known_first_party=rest_framework, tests +include_trailing_comma=true +line_length = 88 +balanced_wrapping = true +sections = FUTURE, STDLIB, DJANGO, CMS, THIRDPARTY, FIRSTPARTY, LIB, LOCALFOLDER [coverage:run] # NOTE: source is ignored with pytest-cov (but uses the same). diff --git a/setup.py b/setup.py index cb850a3ae..769f71322 100755 --- a/setup.py +++ b/setup.py @@ -10,21 +10,21 @@ from setuptools import find_packages, setup def read(f): - return open(f, 'r', encoding='utf-8').read() + return open(f, "r", encoding="utf-8").read() def get_version(package): """ Return package version as listed in `__version__` in `init.py`. """ - init_py = open(os.path.join(package, '__init__.py')).read() + init_py = open(os.path.join(package, "__init__.py")).read() return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1) -version = get_version('rest_framework') +version = get_version("rest_framework") -if sys.argv[-1] == 'publish': +if sys.argv[-1] == "publish": if os.system("pip freeze | grep twine"): print("twine not installed.\nUse `pip install twine`.\nExiting.") sys.exit() @@ -33,48 +33,48 @@ if sys.argv[-1] == 'publish': print("You probably want to also tag the version now:") print(" git tag -a %s -m 'version %s'" % (version, version)) print(" git push --tags") - shutil.rmtree('dist') - shutil.rmtree('build') - shutil.rmtree('djangorestframework.egg-info') + shutil.rmtree("dist") + shutil.rmtree("build") + shutil.rmtree("djangorestframework.egg-info") sys.exit() setup( - name='djangorestframework', + name="djangorestframework", version=version, - url='https://www.django-rest-framework.org/', - license='BSD', - description='Web APIs for Django, made easy.', - long_description=read('README.md'), - long_description_content_type='text/markdown', - author='Tom Christie', - author_email='tom@tomchristie.com', # SEE NOTE BELOW (*) - packages=find_packages(exclude=['tests*']), + url="https://www.django-rest-framework.org/", + license="BSD", + description="Web APIs for Django, made easy.", + long_description=read("README.md"), + long_description_content_type="text/markdown", + author="Tom Christie", + author_email="tom@tomchristie.com", # SEE NOTE BELOW (*) + packages=find_packages(exclude=["tests*"]), include_package_data=True, install_requires=[], python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", zip_safe=False, classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Environment :: Web Environment', - 'Framework :: Django', - 'Framework :: Django :: 1.11', - 'Framework :: Django :: 2.0', - 'Framework :: Django :: 2.1', - 'Framework :: Django :: 2.2', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Topic :: Internet :: WWW/HTTP', - ] + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Framework :: Django", + "Framework :: Django :: 1.11", + "Framework :: Django :: 2.0", + "Framework :: Django :: 2.1", + "Framework :: Django :: 2.2", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Topic :: Internet :: WWW/HTTP", + ], ) # (*) Please direct queries to the discussion group, rather than to me directly diff --git a/tests/authentication/migrations/0001_initial.py b/tests/authentication/migrations/0001_initial.py index cfc887240..774b23316 100644 --- a/tests/authentication/migrations/0001_initial.py +++ b/tests/authentication/migrations/0001_initial.py @@ -9,16 +9,22 @@ class Migration(migrations.Migration): initial = True - dependencies = [ - migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ] + dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)] operations = [ migrations.CreateModel( - name='CustomToken', + name="CustomToken", fields=[ - ('key', models.CharField(max_length=40, primary_key=True, serialize=False)), - ('user', models.OneToOneField(on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL)), + ( + "key", + models.CharField(max_length=40, primary_key=True, serialize=False), + ), + ( + "user", + models.OneToOneField( + on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL + ), + ), ], - ), + ) ] diff --git a/tests/authentication/test_authentication.py b/tests/authentication/test_authentication.py index 793773542..dfbfdf0e1 100644 --- a/tests/authentication/test_authentication.py +++ b/tests/authentication/test_authentication.py @@ -13,11 +13,18 @@ from django.test import TestCase, override_settings from django.utils import six from rest_framework import ( - HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status + HTTP_HEADER_ENCODING, + exceptions, + permissions, + renderers, + status, ) from rest_framework.authentication import ( - BaseAuthentication, BasicAuthentication, RemoteUserAuthentication, - SessionAuthentication, TokenAuthentication + BaseAuthentication, + BasicAuthentication, + RemoteUserAuthentication, + SessionAuthentication, + TokenAuthentication, ) from rest_framework.authtoken.models import Token from rest_framework.authtoken.views import obtain_auth_token @@ -27,6 +34,7 @@ from rest_framework.views import APIView from .models import CustomToken + factory = APIRequestFactory() @@ -35,92 +43,77 @@ class CustomTokenAuthentication(TokenAuthentication): class CustomKeywordTokenAuthentication(TokenAuthentication): - keyword = 'Bearer' + keyword = "Bearer" class MockView(APIView): permission_classes = (permissions.IsAuthenticated,) def get(self, request): - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + return HttpResponse({"a": 1, "b": 2, "c": 3}) def post(self, request): - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + return HttpResponse({"a": 1, "b": 2, "c": 3}) def put(self, request): - return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + return HttpResponse({"a": 1, "b": 2, "c": 3}) urlpatterns = [ url( - r'^session/$', - MockView.as_view(authentication_classes=[SessionAuthentication]) + r"^session/$", MockView.as_view(authentication_classes=[SessionAuthentication]) + ), + url(r"^basic/$", MockView.as_view(authentication_classes=[BasicAuthentication])), + url( + r"^remote-user/$", + MockView.as_view(authentication_classes=[RemoteUserAuthentication]), + ), + url(r"^token/$", MockView.as_view(authentication_classes=[TokenAuthentication])), + url( + r"^customtoken/$", + MockView.as_view(authentication_classes=[CustomTokenAuthentication]), ), url( - r'^basic/$', - MockView.as_view(authentication_classes=[BasicAuthentication]) + r"^customkeywordtoken/$", + MockView.as_view(authentication_classes=[CustomKeywordTokenAuthentication]), ), - url( - r'^remote-user/$', - MockView.as_view(authentication_classes=[RemoteUserAuthentication]) - ), - url( - r'^token/$', - MockView.as_view(authentication_classes=[TokenAuthentication]) - ), - url( - r'^customtoken/$', - MockView.as_view(authentication_classes=[CustomTokenAuthentication]) - ), - url( - r'^customkeywordtoken/$', - MockView.as_view( - authentication_classes=[CustomKeywordTokenAuthentication] - ) - ), - url(r'^auth-token/$', obtain_auth_token), - url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), + url(r"^auth-token/$", obtain_auth_token), + url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")), ] @override_settings(ROOT_URLCONF=__name__) class BasicAuthTests(TestCase): """Basic authentication""" + def setUp(self): self.csrf_client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user( - self.username, self.email, self.password - ) + self.username = "john" + self.email = "lennon@thebeatles.com" + self.password = "password" + self.user = User.objects.create_user(self.username, self.email, self.password) def test_post_form_passing_basic_auth(self): """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" - credentials = ('%s:%s' % (self.username, self.password)) + credentials = "%s:%s" % (self.username, self.password) base64_credentials = base64.b64encode( credentials.encode(HTTP_HEADER_ENCODING) ).decode(HTTP_HEADER_ENCODING) - auth = 'Basic %s' % base64_credentials + auth = "Basic %s" % base64_credentials response = self.csrf_client.post( - '/basic/', - {'example': 'example'}, - HTTP_AUTHORIZATION=auth + "/basic/", {"example": "example"}, HTTP_AUTHORIZATION=auth ) assert response.status_code == status.HTTP_200_OK def test_post_json_passing_basic_auth(self): """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" - credentials = ('%s:%s' % (self.username, self.password)) + credentials = "%s:%s" % (self.username, self.password) base64_credentials = base64.b64encode( credentials.encode(HTTP_HEADER_ENCODING) ).decode(HTTP_HEADER_ENCODING) - auth = 'Basic %s' % base64_credentials + auth = "Basic %s" % base64_credentials response = self.csrf_client.post( - '/basic/', - {'example': 'example'}, - format='json', - HTTP_AUTHORIZATION=auth + "/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth ) assert response.status_code == status.HTTP_200_OK @@ -128,39 +121,34 @@ class BasicAuthTests(TestCase): """Ensure POSTing JSON over basic auth with incorrectly padded Base64 string is handled correctly""" # regression test for issue in 'rest_framework.authentication.BasicAuthentication.authenticate' # https://github.com/encode/django-rest-framework/issues/4089 - auth = 'Basic =a=' + auth = "Basic =a=" response = self.csrf_client.post( - '/basic/', - {'example': 'example'}, - format='json', - HTTP_AUTHORIZATION=auth + "/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth ) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_post_form_failing_basic_auth(self): """Ensure POSTing form over basic auth without correct credentials fails""" - response = self.csrf_client.post('/basic/', {'example': 'example'}) + response = self.csrf_client.post("/basic/", {"example": "example"}) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" response = self.csrf_client.post( - '/basic/', - {'example': 'example'}, - format='json' + "/basic/", {"example": "example"}, format="json" ) assert response.status_code == status.HTTP_401_UNAUTHORIZED - assert response['WWW-Authenticate'] == 'Basic realm="api"' + assert response["WWW-Authenticate"] == 'Basic realm="api"' def test_fail_post_if_credentials_are_missing(self): response = self.csrf_client.post( - '/basic/', {'example': 'example'}, HTTP_AUTHORIZATION='Basic ') + "/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic " + ) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_fail_post_if_credentials_contain_spaces(self): response = self.csrf_client.post( - '/basic/', {'example': 'example'}, - HTTP_AUTHORIZATION='Basic foo bar' + "/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic foo bar" ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -168,15 +156,14 @@ class BasicAuthTests(TestCase): @override_settings(ROOT_URLCONF=__name__) class SessionAuthTests(TestCase): """User session authentication""" + def setUp(self): self.csrf_client = APIClient(enforce_csrf_checks=True) self.non_csrf_client = APIClient(enforce_csrf_checks=False) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user( - self.username, self.email, self.password - ) + self.username = "john" + self.email = "lennon@thebeatles.com" + self.password = "password" + self.user = User.objects.create_user(self.username, self.email, self.password) def tearDown(self): self.csrf_client.logout() @@ -187,8 +174,8 @@ class SessionAuthTests(TestCase): cf. [#1810](https://github.com/encode/django-rest-framework/pull/1810) """ - response = self.csrf_client.get('/auth/login/') - content = response.content.decode('utf8') + response = self.csrf_client.get("/auth/login/") + content = response.content.decode("utf8") assert '' in content def test_post_form_session_auth_failing_csrf(self): @@ -196,7 +183,7 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication without CSRF token fails. """ self.csrf_client.login(username=self.username, password=self.password) - response = self.csrf_client.post('/session/', {'example': 'example'}) + response = self.csrf_client.post("/session/", {"example": "example"}) assert response.status_code == status.HTTP_403_FORBIDDEN def test_post_form_session_auth_passing_csrf(self): @@ -213,10 +200,9 @@ class SessionAuthTests(TestCase): self.csrf_client.cookies[settings.CSRF_COOKIE_NAME] = token # Post the token matching the cookie value - response = self.csrf_client.post('/session/', { - 'example': 'example', - 'csrfmiddlewaretoken': token, - }) + response = self.csrf_client.post( + "/session/", {"example": "example", "csrfmiddlewaretoken": token} + ) assert response.status_code == status.HTTP_200_OK def test_post_form_session_auth_passing(self): @@ -224,12 +210,8 @@ class SessionAuthTests(TestCase): Ensure POSTing form over session authentication with logged in user and CSRF token passes. """ - self.non_csrf_client.login( - username=self.username, password=self.password - ) - response = self.non_csrf_client.post( - '/session/', {'example': 'example'} - ) + self.non_csrf_client.login(username=self.username, password=self.password) + response = self.non_csrf_client.post("/session/", {"example": "example"}) assert response.status_code == status.HTTP_200_OK def test_put_form_session_auth_passing(self): @@ -237,38 +219,33 @@ class SessionAuthTests(TestCase): Ensure PUTting form over session authentication with logged in user and CSRF token passes. """ - self.non_csrf_client.login( - username=self.username, password=self.password - ) - response = self.non_csrf_client.put( - '/session/', {'example': 'example'} - ) + self.non_csrf_client.login(username=self.username, password=self.password) + response = self.non_csrf_client.put("/session/", {"example": "example"}) assert response.status_code == status.HTTP_200_OK def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. """ - response = self.csrf_client.post('/session/', {'example': 'example'}) + response = self.csrf_client.post("/session/", {"example": "example"}) assert response.status_code == status.HTTP_403_FORBIDDEN class BaseTokenAuthTests(object): """Token authentication""" + model = None path = None - header_prefix = 'Token ' + header_prefix = "Token " def setUp(self): self.csrf_client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user( - self.username, self.email, self.password - ) + self.username = "john" + self.email = "lennon@thebeatles.com" + self.password = "password" + self.user = User.objects.create_user(self.username, self.email, self.password) - self.key = 'abcd1234' + self.key = "abcd1234" self.token = self.model.objects.create(key=self.key, user=self.user) def test_post_form_passing_token_auth(self): @@ -278,39 +255,41 @@ class BaseTokenAuthTests(object): """ auth = self.header_prefix + self.key response = self.csrf_client.post( - self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth + self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth ) assert response.status_code == status.HTTP_200_OK def test_fail_authentication_if_user_is_not_active(self): - user = User.objects.create_user('foo', 'bar', 'baz') + user = User.objects.create_user("foo", "bar", "baz") user.is_active = False user.save() - self.model.objects.create(key='foobar_token', user=user) + self.model.objects.create(key="foobar_token", user=user) response = self.csrf_client.post( - self.path, {'example': 'example'}, - HTTP_AUTHORIZATION=self.header_prefix + 'foobar_token' + self.path, + {"example": "example"}, + HTTP_AUTHORIZATION=self.header_prefix + "foobar_token", ) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_fail_post_form_passing_nonexistent_token_auth(self): # use a nonexistent token key - auth = self.header_prefix + 'wxyz6789' + auth = self.header_prefix + "wxyz6789" response = self.csrf_client.post( - self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth + self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth ) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_fail_post_if_token_is_missing(self): response = self.csrf_client.post( - self.path, {'example': 'example'}, - HTTP_AUTHORIZATION=self.header_prefix) + self.path, {"example": "example"}, HTTP_AUTHORIZATION=self.header_prefix + ) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_fail_post_if_token_contains_spaces(self): response = self.csrf_client.post( - self.path, {'example': 'example'}, - HTTP_AUTHORIZATION=self.header_prefix + 'foo bar' + self.path, + {"example": "example"}, + HTTP_AUTHORIZATION=self.header_prefix + "foo bar", ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -318,7 +297,7 @@ class BaseTokenAuthTests(object): # add an 'invalid' unicode character auth = self.header_prefix + self.key + "¸" response = self.csrf_client.post( - self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth + self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -329,8 +308,7 @@ class BaseTokenAuthTests(object): """ auth = self.header_prefix + self.key response = self.csrf_client.post( - self.path, {'example': 'example'}, - format='json', HTTP_AUTHORIZATION=auth + self.path, {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth ) assert response.status_code == status.HTTP_200_OK @@ -343,8 +321,10 @@ class BaseTokenAuthTests(object): def func_to_test(): return self.csrf_client.post( - self.path, {'example': 'example'}, - format='json', HTTP_AUTHORIZATION=auth + self.path, + {"example": "example"}, + format="json", + HTTP_AUTHORIZATION=auth, ) self.assertNumQueries(1, func_to_test) @@ -353,7 +333,7 @@ class BaseTokenAuthTests(object): """ Ensure POSTing form over token auth without correct credentials fails """ - response = self.csrf_client.post(self.path, {'example': 'example'}) + response = self.csrf_client.post(self.path, {"example": "example"}) assert response.status_code == status.HTTP_401_UNAUTHORIZED def test_post_json_failing_token_auth(self): @@ -361,7 +341,7 @@ class BaseTokenAuthTests(object): Ensure POSTing json over token auth without correct credentials fails """ response = self.csrf_client.post( - self.path, {'example': 'example'}, format='json' + self.path, {"example": "example"}, format="json" ) assert response.status_code == status.HTTP_401_UNAUTHORIZED @@ -369,7 +349,7 @@ class BaseTokenAuthTests(object): @override_settings(ROOT_URLCONF=__name__) class TokenAuthTests(BaseTokenAuthTests, TestCase): model = Token - path = '/token/' + path = "/token/" def test_token_has_auto_assigned_key_if_none_provided(self): """Ensure creating a token with no key will auto-assign a key""" @@ -387,12 +367,12 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase): """Ensure token login view using JSON POST works.""" client = APIClient(enforce_csrf_checks=True) response = client.post( - '/auth-token/', - {'username': self.username, 'password': self.password}, - format='json' + "/auth-token/", + {"username": self.username, "password": self.password}, + format="json", ) assert response.status_code == status.HTTP_200_OK - assert response.data['token'] == self.key + assert response.data["token"] == self.key def test_token_login_json_bad_creds(self): """ @@ -401,41 +381,41 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase): """ client = APIClient(enforce_csrf_checks=True) response = client.post( - '/auth-token/', - {'username': self.username, 'password': "badpass"}, - format='json' + "/auth-token/", + {"username": self.username, "password": "badpass"}, + format="json", ) assert response.status_code == 400 def test_token_login_json_missing_fields(self): """Ensure token login view using JSON POST fails if missing fields.""" client = APIClient(enforce_csrf_checks=True) - response = client.post('/auth-token/', - {'username': self.username}, format='json') + response = client.post( + "/auth-token/", {"username": self.username}, format="json" + ) assert response.status_code == 400 def test_token_login_form(self): """Ensure token login view using form POST works.""" client = APIClient(enforce_csrf_checks=True) response = client.post( - '/auth-token/', - {'username': self.username, 'password': self.password} + "/auth-token/", {"username": self.username, "password": self.password} ) assert response.status_code == status.HTTP_200_OK - assert response.data['token'] == self.key + assert response.data["token"] == self.key @override_settings(ROOT_URLCONF=__name__) class CustomTokenAuthTests(BaseTokenAuthTests, TestCase): model = CustomToken - path = '/customtoken/' + path = "/customtoken/" @override_settings(ROOT_URLCONF=__name__) class CustomKeywordTokenAuthTests(BaseTokenAuthTests, TestCase): model = Token - path = '/customkeywordtoken/' - header_prefix = 'Bearer ' + path = "/customkeywordtoken/" + header_prefix = "Bearer " class IncorrectCredentialsTests(TestCase): @@ -445,42 +425,42 @@ class IncorrectCredentialsTests(TestCase): authentication should run and error, even if no permissions are set on the view. """ + class IncorrectCredentialsAuth(BaseAuthentication): def authenticate(self, request): - raise exceptions.AuthenticationFailed('Bad credentials') + raise exceptions.AuthenticationFailed("Bad credentials") - request = factory.get('/') + request = factory.get("/") view = MockView.as_view( - authentication_classes=(IncorrectCredentialsAuth,), - permission_classes=() + authentication_classes=(IncorrectCredentialsAuth,), permission_classes=() ) response = view(request) assert response.status_code == status.HTTP_403_FORBIDDEN - assert response.data == {'detail': 'Bad credentials'} + assert response.data == {"detail": "Bad credentials"} class FailingAuthAccessedInRenderer(TestCase): def setUp(self): class AuthAccessingRenderer(renderers.BaseRenderer): - media_type = 'text/plain' - format = 'txt' + media_type = "text/plain" + format = "txt" def render(self, data, media_type=None, renderer_context=None): - request = renderer_context['request'] + request = renderer_context["request"] if request.user.is_authenticated: - return b'authenticated' - return b'not authenticated' + return b"authenticated" + return b"not authenticated" class FailingAuth(BaseAuthentication): def authenticate(self, request): - raise exceptions.AuthenticationFailed('authentication failed') + raise exceptions.AuthenticationFailed("authentication failed") class ExampleView(APIView): authentication_classes = (FailingAuth,) renderer_classes = (AuthAccessingRenderer,) def get(self, request): - return Response({'foo': 'bar'}) + return Response({"foo": "bar"}) self.view = ExampleView.as_view() @@ -490,10 +470,10 @@ class FailingAuthAccessedInRenderer(TestCase): `request.user` without raising an exception. Particularly relevant to HTML responses that might reasonably access `request.user`. """ - request = factory.get('/') + request = factory.get("/") response = self.view(request) content = response.render().content - assert content == b'not authenticated' + assert content == b"not authenticated" class NoAuthenticationClassesTests(TestCase): @@ -505,23 +485,21 @@ class NoAuthenticationClassesTests(TestCase): """ class DummyPermission(permissions.BasePermission): - message = 'Dummy permission message' + message = "Dummy permission message" def has_permission(self, request, view): return False - request = factory.get('/') + request = factory.get("/") view = MockView.as_view( - authentication_classes=(), - permission_classes=(DummyPermission,), + authentication_classes=(), permission_classes=(DummyPermission,) ) response = view(request) assert response.status_code == status.HTTP_403_FORBIDDEN - assert response.data == {'detail': 'Dummy permission message'} + assert response.data == {"detail": "Dummy permission message"} class BasicAuthenticationUnitTests(TestCase): - def test_base_authentication_abstract_method(self): with pytest.raises(NotImplementedError): BaseAuthentication().authenticate({}) @@ -529,34 +507,34 @@ class BasicAuthenticationUnitTests(TestCase): def test_basic_authentication_raises_error_if_user_not_found(self): auth = BasicAuthentication() with pytest.raises(exceptions.AuthenticationFailed): - auth.authenticate_credentials('invalid id', 'invalid password') + auth.authenticate_credentials("invalid id", "invalid password") def test_basic_authentication_raises_error_if_user_not_active(self): from rest_framework import authentication class MockUser(object): is_active = False + old_authenticate = authentication.authenticate authentication.authenticate = lambda **kwargs: MockUser() auth = authentication.BasicAuthentication() with pytest.raises(exceptions.AuthenticationFailed) as error: - auth.authenticate_credentials('foo', 'bar') - assert 'User inactive or deleted.' in str(error) + auth.authenticate_credentials("foo", "bar") + assert "User inactive or deleted." in str(error) authentication.authenticate = old_authenticate -@override_settings(ROOT_URLCONF=__name__, - AUTHENTICATION_BACKENDS=('django.contrib.auth.backends.RemoteUserBackend',)) +@override_settings( + ROOT_URLCONF=__name__, + AUTHENTICATION_BACKENDS=("django.contrib.auth.backends.RemoteUserBackend",), +) class RemoteUserAuthenticationUnitTests(TestCase): def setUp(self): - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user( - self.username, self.email, self.password - ) + self.username = "john" + self.email = "lennon@thebeatles.com" + self.password = "password" + self.user = User.objects.create_user(self.username, self.email, self.password) def test_remote_user_works(self): - response = self.client.post('/remote-user/', - REMOTE_USER=self.username) + response = self.client.post("/remote-user/", REMOTE_USER=self.username) self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/tests/browsable_api/auth_urls.py b/tests/browsable_api/auth_urls.py index 0e9379717..8dd78aef1 100644 --- a/tests/browsable_api/auth_urls.py +++ b/tests/browsable_api/auth_urls.py @@ -4,7 +4,8 @@ from django.conf.urls import include, url from .views import MockView + urlpatterns = [ - url(r'^$', MockView.as_view()), - url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), + url(r"^$", MockView.as_view()), + url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")), ] diff --git a/tests/browsable_api/no_auth_urls.py b/tests/browsable_api/no_auth_urls.py index 5fc95c727..505e9a762 100644 --- a/tests/browsable_api/no_auth_urls.py +++ b/tests/browsable_api/no_auth_urls.py @@ -4,6 +4,5 @@ from django.conf.urls import url from .views import MockView -urlpatterns = [ - url(r'^$', MockView.as_view()), -] + +urlpatterns = [url(r"^$", MockView.as_view())] diff --git a/tests/browsable_api/test_browsable_api.py b/tests/browsable_api/test_browsable_api.py index 684d7ae14..72e2fd1de 100644 --- a/tests/browsable_api/test_browsable_api.py +++ b/tests/browsable_api/test_browsable_api.py @@ -6,71 +6,65 @@ from django.test import TestCase, override_settings from rest_framework.test import APIClient -@override_settings(ROOT_URLCONF='tests.browsable_api.auth_urls') +@override_settings(ROOT_URLCONF="tests.browsable_api.auth_urls") class DropdownWithAuthTests(TestCase): """Tests correct dropdown behaviour with Auth views enabled.""" + def setUp(self): self.client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user( - self.username, - self.email, - self.password - ) + self.username = "john" + self.email = "lennon@thebeatles.com" + self.password = "password" + self.user = User.objects.create_user(self.username, self.email, self.password) def tearDown(self): self.client.logout() def test_name_shown_when_logged_in(self): self.client.login(username=self.username, password=self.password) - response = self.client.get('/') - content = response.content.decode('utf8') - assert 'john' in content + response = self.client.get("/") + content = response.content.decode("utf8") + assert "john" in content def test_logout_shown_when_logged_in(self): self.client.login(username=self.username, password=self.password) - response = self.client.get('/') - content = response.content.decode('utf8') - assert '>Log out<' in content + response = self.client.get("/") + content = response.content.decode("utf8") + assert ">Log out<" in content def test_login_shown_when_logged_out(self): - response = self.client.get('/') - content = response.content.decode('utf8') - assert '>Log in<' in content + response = self.client.get("/") + content = response.content.decode("utf8") + assert ">Log in<" in content -@override_settings(ROOT_URLCONF='tests.browsable_api.no_auth_urls') +@override_settings(ROOT_URLCONF="tests.browsable_api.no_auth_urls") class NoDropdownWithoutAuthTests(TestCase): """Tests correct dropdown behaviour with Auth views NOT enabled.""" + def setUp(self): self.client = APIClient(enforce_csrf_checks=True) - self.username = 'john' - self.email = 'lennon@thebeatles.com' - self.password = 'password' - self.user = User.objects.create_user( - self.username, - self.email, - self.password - ) + self.username = "john" + self.email = "lennon@thebeatles.com" + self.password = "password" + self.user = User.objects.create_user(self.username, self.email, self.password) def tearDown(self): self.client.logout() def test_name_shown_when_logged_in(self): self.client.login(username=self.username, password=self.password) - response = self.client.get('/') - content = response.content.decode('utf8') - assert 'john' in content + response = self.client.get("/") + content = response.content.decode("utf8") + assert "john" in content def test_dropdown_not_shown_when_logged_in(self): self.client.login(username=self.username, password=self.password) - response = self.client.get('/') - content = response.content.decode('utf8') + response = self.client.get("/") + content = response.content.decode("utf8") assert '