Merge remote-tracking branch 'origin/master' into feature/generateschema-version-arg-support

This commit is contained in:
Cihan Eran 2022-09-21 16:48:20 +03:00
commit b005fd241a
No known key found for this signature in database
GPG Key ID: 8AB0141F0F79EBFE
38 changed files with 329 additions and 275 deletions

View File

@ -1,7 +1,7 @@
include README.md include README.md
include LICENSE.md include LICENSE.md
recursive-include tests/ * recursive-include tests/ *
recursive-include rest_framework/static *.js *.css *.png *.ico *.eot *.svg *.ttf *.woff *.woff2 recursive-include rest_framework/static *.js *.css *.map *.png *.ico *.eot *.svg *.ttf *.woff *.woff2
recursive-include rest_framework/templates *.html schema.js recursive-include rest_framework/templates *.html schema.js
recursive-include rest_framework/locale *.mo recursive-include rest_framework/locale *.mo
global-exclude __pycache__ global-exclude __pycache__

View File

@ -54,8 +54,8 @@ There is a live example API for testing purposes, [available here][sandbox].
# Requirements # Requirements
* Python (3.6, 3.7, 3.8, 3.9, 3.10) * Python 3.6+
* Django (2.2, 3.0, 3.1, 3.2, 4.0, 4.1) * Django 4.1, 4.0, 3.2, 3.1, 3.0
We **highly recommend** and only officially support the latest patch release of We **highly recommend** and only officially support the latest patch release of
each Python and Django series. each Python and Django series.
@ -90,9 +90,10 @@ Startup up a new project like so...
Now edit the `example/urls.py` module in your project: Now edit the `example/urls.py` module in your project:
```python ```python
from django.urls import path, include
from django.contrib.auth.models import User from django.contrib.auth.models import User
from rest_framework import serializers, viewsets, routers from django.urls import include, path
from rest_framework import routers, serializers, viewsets
# Serializers define the API representation. # Serializers define the API representation.
class UserSerializer(serializers.HyperlinkedModelSerializer): class UserSerializer(serializers.HyperlinkedModelSerializer):
@ -111,7 +112,6 @@ class UserViewSet(viewsets.ModelViewSet):
router = routers.DefaultRouter() router = routers.DefaultRouter()
router.register(r'users', UserViewSet) router.register(r'users', UserViewSet)
# Wire up our API using automatic URL routing. # Wire up our API using automatic URL routing.
# Additionally, we include login URLs for the browsable API. # Additionally, we include login URLs for the browsable API.
urlpatterns = [ urlpatterns = [
@ -185,7 +185,7 @@ Please see the [security policy][security-policy].
[codecov]: https://codecov.io/github/encode/django-rest-framework?branch=master [codecov]: https://codecov.io/github/encode/django-rest-framework?branch=master
[pypi-version]: https://img.shields.io/pypi/v/djangorestframework.svg [pypi-version]: https://img.shields.io/pypi/v/djangorestframework.svg
[pypi]: https://pypi.org/project/djangorestframework/ [pypi]: https://pypi.org/project/djangorestframework/
[twitter]: https://twitter.com/_tomchristie [twitter]: https://twitter.com/starletdreaming
[group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework [group]: https://groups.google.com/forum/?fromgroups#!forum/django-rest-framework
[sandbox]: https://restframework.herokuapp.com/ [sandbox]: https://restframework.herokuapp.com/

View File

@ -173,9 +173,9 @@ The `curl` command line tool may be useful for testing token authenticated APIs.
--- ---
#### Generating Tokens ### Generating Tokens
##### By using signals #### By using signals
If you want every user to have an automatically generated Token, you can simply catch the User's `post_save` signal. If you want every user to have an automatically generated Token, you can simply catch the User's `post_save` signal.
@ -199,7 +199,7 @@ If you've already created some users, you can generate tokens for all existing u
for user in User.objects.all(): for user in User.objects.all():
Token.objects.get_or_create(user=user) Token.objects.get_or_create(user=user)
##### By exposing an api endpoint #### By exposing an api endpoint
When using `TokenAuthentication`, you may want to provide a mechanism for clients to obtain a token given the username and password. REST framework provides a built-in view to provide this behaviour. To use it, add the `obtain_auth_token` view to your URLconf: When using `TokenAuthentication`, you may want to provide a mechanism for clients to obtain a token given the username and password. REST framework provides a built-in view to provide this behaviour. To use it, add the `obtain_auth_token` view to your URLconf:
@ -248,7 +248,7 @@ And in your `urls.py`:
] ]
##### With Django admin #### With Django admin
It is also possible to create Tokens manually through the admin interface. In case you are using a large user base, we recommend that you monkey patch the `TokenAdmin` class customize it to your needs, more specifically by declaring the `user` field as `raw_field`. It is also possible to create Tokens manually through the admin interface. In case you are using a large user base, we recommend that you monkey patch the `TokenAdmin` class customize it to your needs, more specifically by declaring the `user` field as `raw_field`.
@ -369,7 +369,7 @@ The following third-party packages are also available.
The [Django OAuth Toolkit][django-oauth-toolkit] package provides OAuth 2.0 support and works with Python 3.4+. The package is maintained by [jazzband][jazzband] and uses the excellent [OAuthLib][oauthlib]. The package is well documented, and well supported and is currently our **recommended package for OAuth 2.0 support**. The [Django OAuth Toolkit][django-oauth-toolkit] package provides OAuth 2.0 support and works with Python 3.4+. The package is maintained by [jazzband][jazzband] and uses the excellent [OAuthLib][oauthlib]. The package is well documented, and well supported and is currently our **recommended package for OAuth 2.0 support**.
#### Installation & configuration ### Installation & configuration
Install using `pip`. Install using `pip`.
@ -396,7 +396,7 @@ The [Django REST framework OAuth][django-rest-framework-oauth] package provides
This package was previously included directly in the REST framework but is now supported and maintained as a third-party package. This package was previously included directly in the REST framework but is now supported and maintained as a third-party package.
#### Installation & configuration ### Installation & configuration
Install the package using `pip`. Install the package using `pip`.

View File

@ -159,14 +159,6 @@ Corresponds to `django.db.models.fields.BooleanField`.
**Signature:** `BooleanField()` **Signature:** `BooleanField()`
## NullBooleanField
A boolean representation that also accepts `None` as a valid value.
Corresponds to `django.db.models.fields.NullBooleanField`.
**Signature:** `NullBooleanField()`
--- ---
# String fields # String fields

View File

@ -65,7 +65,7 @@ The following attributes control the basic view behavior.
* `queryset` - The queryset that should be used for returning objects from this view. Typically, you must either set this attribute, or override the `get_queryset()` method. If you are overriding a view method, it is important that you call `get_queryset()` instead of accessing this property directly, as `queryset` will get evaluated once, and those results will be cached for all subsequent requests. * `queryset` - The queryset that should be used for returning objects from this view. Typically, you must either set this attribute, or override the `get_queryset()` method. If you are overriding a view method, it is important that you call `get_queryset()` instead of accessing this property directly, as `queryset` will get evaluated once, and those results will be cached for all subsequent requests.
* `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output. Typically, you must either set this attribute, or override the `get_serializer_class()` method. * `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output. Typically, you must either set this attribute, or override the `get_serializer_class()` method.
* `lookup_field` - The model field that should be used to for performing object lookup of individual model instances. Defaults to `'pk'`. Note that when using hyperlinked APIs you'll need to ensure that *both* the API views *and* the serializer classes set the lookup fields if you need to use a custom value. * `lookup_field` - The model field that should be used for performing object lookup of individual model instances. Defaults to `'pk'`. Note that when using hyperlinked APIs you'll need to ensure that *both* the API views *and* the serializer classes set the lookup fields if you need to use a custom value.
* `lookup_url_kwarg` - The URL keyword argument that should be used for object lookup. The URL conf should include a keyword argument corresponding to this value. If unset this defaults to using the same value as `lookup_field`. * `lookup_url_kwarg` - The URL keyword argument that should be used for object lookup. The URL conf should include a keyword argument corresponding to this value. If unset this defaults to using the same value as `lookup_field`.
**Pagination**: **Pagination**:
@ -217,7 +217,7 @@ If the request data provided for creating the object was invalid, a `400 Bad Req
Provides a `.retrieve(request, *args, **kwargs)` method, that implements returning an existing model instance in a response. Provides a `.retrieve(request, *args, **kwargs)` method, that implements returning an existing model instance in a response.
If an object can be retrieved this returns a `200 OK` response, with a serialized representation of the object as the body of the response. Otherwise it will return a `404 Not Found`. If an object can be retrieved this returns a `200 OK` response, with a serialized representation of the object as the body of the response. Otherwise, it will return a `404 Not Found`.
## UpdateModelMixin ## UpdateModelMixin
@ -335,7 +335,7 @@ For example, if you need to lookup objects based on multiple fields in the URL c
queryset = self.filter_queryset(queryset) # Apply any filter backends queryset = self.filter_queryset(queryset) # Apply any filter backends
filter = {} filter = {}
for field in self.lookup_fields: for field in self.lookup_fields:
if self.kwargs[field]: # Ignore empty fields. if self.kwargs.get(field): # Ignore empty fields.
filter[field] = self.kwargs[field] filter[field] = self.kwargs[field]
obj = get_object_or_404(queryset, **filter) # Lookup the object obj = get_object_or_404(queryset, **filter) # Lookup the object
self.check_object_permissions(self.request, obj) self.check_object_permissions(self.request, obj)

View File

@ -171,7 +171,7 @@ This permission is suitable if you want to your API to allow read permissions to
## DjangoModelPermissions ## DjangoModelPermissions
This permission class ties into Django's standard `django.contrib.auth` [model permissions][contribauth]. This permission must only be applied to views that have a `.queryset` property or `get_queryset()` method. Authorization will only be granted if the user *is authenticated* and has the *relevant model permissions* assigned. This permission class ties into Django's standard `django.contrib.auth` [model permissions][contribauth]. This permission must only be applied to views that have a `.queryset` property or `get_queryset()` method. Authorization will only be granted if the user *is authenticated* and has the *relevant model permissions* assigned. The appropriate model is determined by checking `get_queryset().model` or `queryset.model`.
* `POST` requests require the user to have the `add` permission on the model. * `POST` requests require the user to have the `add` permission on the model.
* `PUT` and `PATCH` requests require the user to have the `change` permission on the model. * `PUT` and `PATCH` requests require the user to have the `change` permission on the model.

View File

@ -602,7 +602,7 @@ A mapping of Django model fields to REST framework serializer fields. You can ov
This property should be the serializer field class, that is used for relational fields by default. This property should be the serializer field class, that is used for relational fields by default.
For `ModelSerializer` this defaults to `PrimaryKeyRelatedField`. For `ModelSerializer` this defaults to `serializers.PrimaryKeyRelatedField`.
For `HyperlinkedModelSerializer` this defaults to `serializers.HyperlinkedRelatedField`. For `HyperlinkedModelSerializer` this defaults to `serializers.HyperlinkedRelatedField`.
@ -886,7 +886,7 @@ Because this class provides the same interface as the `Serializer` class, you ca
The only difference you'll notice when doing so is the `BaseSerializer` classes will not generate HTML forms in the browsable API. This is because the data they return does not include all the field information that would allow each field to be rendered into a suitable HTML input. The only difference you'll notice when doing so is the `BaseSerializer` classes will not generate HTML forms in the browsable API. This is because the data they return does not include all the field information that would allow each field to be rendered into a suitable HTML input.
##### Read-only `BaseSerializer` classes #### Read-only `BaseSerializer` classes
To implement a read-only serializer using the `BaseSerializer` class, we just need to override the `.to_representation()` method. Let's take a look at an example using a simple Django model: To implement a read-only serializer using the `BaseSerializer` class, we just need to override the `.to_representation()` method. Let's take a look at an example using a simple Django model:
@ -920,7 +920,7 @@ Or use it to serialize multiple instances:
serializer = HighScoreSerializer(queryset, many=True) serializer = HighScoreSerializer(queryset, many=True)
return Response(serializer.data) return Response(serializer.data)
##### Read-write `BaseSerializer` classes #### Read-write `BaseSerializer` classes
To create a read-write serializer we first need to implement a `.to_internal_value()` method. This method returns the validated values that will be used to construct the object instance, and may raise a `serializers.ValidationError` if the supplied data is in an incorrect format. To create a read-write serializer we first need to implement a `.to_internal_value()` method. This method returns the validated values that will be used to construct the object instance, and may raise a `serializers.ValidationError` if the supplied data is in an incorrect format.
@ -969,7 +969,7 @@ Here's a complete example of our previous `HighScoreSerializer`, that's been upd
The `BaseSerializer` class is also useful if you want to implement new generic serializer classes for dealing with particular serialization styles, or for integrating with alternative storage backends. The `BaseSerializer` class is also useful if you want to implement new generic serializer classes for dealing with particular serialization styles, or for integrating with alternative storage backends.
The following class is an example of a generic serializer that can handle coercing arbitrary objects into primitive representations. The following class is an example of a generic serializer that can handle coercing arbitrary complex objects into primitive representations.
class ObjectSerializer(serializers.BaseSerializer): class ObjectSerializer(serializers.BaseSerializer):
""" """

View File

@ -0,0 +1,62 @@
<style>
.promo li a {
float: left;
width: 130px;
height: 20px;
text-align: center;
margin: 10px 30px;
padding: 150px 0 0 0;
background-position: 0 50%;
background-size: 130px auto;
background-repeat: no-repeat;
font-size: 120%;
color: black;
}
.promo li {
list-style: none;
}
</style>
# Django REST framework 3.14
## Django 4.1 support
The latest release now fully supports Django 4.1.
Our requirements are now:
* Python 3.6+
* Django 4.1, 4.0, 3.2, 3.1, 3.0
## `raise_exceptions` argument for `is_valid` is now keyword-only.
Calling `serializer_instance.is_valid(True)` is no longer acceptable syntax.
If you'd like to use the `raise_exceptions` argument, you must use it as a
keyword argument.
See Pull Request [#7952](https://github.com/encode/django-rest-framework/pull/7952) for more details.
## `ManyRelatedField` supports returning the default when the source attribute doesn't exist.
Previously, if you used a serializer field with `many=True` with a dot notated source field
that didn't exist, it would raise an `AttributeError`. Now it will return the default or be
skipped depending on the other arguments.
See Pull Request [#7574](https://github.com/encode/django-rest-framework/pull/7574) for more details.
## Make Open API `get_reference` public.
Returns a reference to the serializer component. This may be useful if you override `get_schema()`.
## Change semantic of OR of two permission classes.
When OR-ing two permissions, the request has to pass either class's `has_permission() and has_object_permission()`.
Previously, both class's `has_permission()` was ignored when OR-ing two permissions together.
See Pull Request [#7522](https://github.com/encode/django-rest-framework/pull/7522) for more details.
## Minor fixes and improvements
There are a number of minor fixes and improvements in this release. See the [release notes](release-notes.md) page for a complete listing.

View File

@ -34,6 +34,23 @@ You can determine your currently installed version using `pip show`:
--- ---
## 3.14.x series
### 3.14.0
Date: 22nd September 2022
* Enforce `is_valid(raise_exception=False)` as a keyword-only argument. [[#7952](https://github.com/encode/django-rest-framework/pull/7952)]
* Django 4.1 compatability. [[#8591](https://github.com/encode/django-rest-framework/pull/8591)]
* Stop calling `set_context` on Validators. [[#8589](https://github.com/encode/django-rest-framework/pull/8589)]
* Return `NotImplemented` from `ErrorDetails.__ne__`. [[#8538](https://github.com/encode/django-rest-framework/pull/8538)]
* Don't evaluate `DateTimeField.default_timezone` when a custom timezone is set. [[#8531](https://github.com/encode/django-rest-framework/pull/8531)]
* Make relative URLs clickable in Browseable API. [[#8464](https://github.com/encode/django-rest-framework/pull/8464)]
* Support `ManyRelatedField` falling back to the default value when the attribute specified by dot notation doesn't exist. Matches `ManyRelatedField.get_attribute` to `Field.get_attribute`. [[#7574](https://github.com/encode/django-rest-framework/pull/7574)]
* Make `schemas.openapi.get_reference` public. [[#7515](https://github.com/encode/django-rest-framework/pull/7515)]
* Make `ReturnDict` support `dict` union operators on Python 3.9 and later. [[#8302](https://github.com/encode/django-rest-framework/pull/8302)]
* Update throttling to check if `request.user` is set before checking if the user is authenticated. [[#8370](https://github.com/encode/django-rest-framework/pull/8370)]
## 3.13.x series ## 3.13.x series
### 3.13.1 ### 3.13.1

View File

@ -112,7 +112,7 @@ Now update the `snippets/urls.py` file slightly, to append a set of `format_suff
urlpatterns = [ urlpatterns = [
path('snippets/', views.snippet_list), path('snippets/', views.snippet_list),
path('snippets/<int:pk>', views.snippet_detail), path('snippets/<int:pk>/', views.snippet_detail),
] ]
urlpatterns = format_suffix_patterns(urlpatterns) urlpatterns = format_suffix_patterns(urlpatterns)

View File

@ -10,7 +10,7 @@ ______ _____ _____ _____ __
import django import django
__title__ = 'Django REST framework' __title__ = 'Django REST framework'
__version__ = '3.13.1' __version__ = '3.14.0'
__author__ = 'Tom Christie' __author__ = 'Tom Christie'
__license__ = 'BSD 3-Clause' __license__ = 'BSD 3-Clause'
__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd' __copyright__ = 'Copyright 2011-2019 Encode OSS Ltd'
@ -35,3 +35,7 @@ class RemovedInDRF313Warning(DeprecationWarning):
class RemovedInDRF314Warning(PendingDeprecationWarning): class RemovedInDRF314Warning(PendingDeprecationWarning):
pass pass
class RemovedInDRF315Warning(PendingDeprecationWarning):
pass

View File

@ -2,6 +2,7 @@
The `compat` module provides support for backwards compatibility with older The `compat` module provides support for backwards compatibility with older
versions of Django/Python, and compatibility wrappers around optional packages. versions of Django/Python, and compatibility wrappers around optional packages.
""" """
import django
from django.conf import settings from django.conf import settings
from django.views.generic import View from django.views.generic import View
@ -152,6 +153,30 @@ else:
return False return False
if django.VERSION >= (4, 2):
# Django 4.2+: use the stock parse_header_parameters function
# Note: Django 4.1 also has an implementation of parse_header_parameters
# which is slightly different from the one in 4.2, it needs
# the compatibility shim as well.
from django.utils.http import parse_header_parameters
else:
# Django <= 4.1: create a compatibility shim for parse_header_parameters
from django.http.multipartparser import parse_header
def parse_header_parameters(line):
# parse_header works with bytes, but parse_header_parameters
# works with strings. Call encode to convert the line to bytes.
main_value_pair, params = parse_header(line.encode())
return main_value_pair, {
# parse_header will convert *some* values to string.
# parse_header_parameters converts *all* values to string.
# Make sure all values are converted by calling decode on
# any remaining non-string values.
k: v if isinstance(v, str) else v.decode()
for k, v in params.items()
}
# `separators` argument to `json.dumps()` differs between 2.x and 3.x # `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767 # See: https://bugs.python.org/issue22767
SHORT_SEPARATORS = (',', ':') SHORT_SEPARATORS = (',', ':')

View File

@ -1,7 +1,7 @@
""" """
Handled exceptions raised by REST framework. Handled exceptions raised by REST framework.
In addition Django's built in 403 and 404 exceptions are handled. In addition, Django's built in 403 and 404 exceptions are handled.
(`django.http.Http404` and `django.core.exceptions.PermissionDenied`) (`django.http.Http404` and `django.core.exceptions.PermissionDenied`)
""" """
import math import math
@ -72,16 +72,19 @@ class ErrorDetail(str):
return self return self
def __eq__(self, other): def __eq__(self, other):
r = super().__eq__(other) result = super().__eq__(other)
if r is NotImplemented: if result is NotImplemented:
return NotImplemented return NotImplemented
try: try:
return r and self.code == other.code return result and self.code == other.code
except AttributeError: except AttributeError:
return r return result
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) result = self.__eq__(other)
if result is NotImplemented:
return NotImplemented
return not result
def __repr__(self): def __repr__(self):
return 'ErrorDetail(string=%r, code=%r)' % ( return 'ErrorDetail(string=%r, code=%r)' % (

View File

@ -5,7 +5,6 @@ import functools
import inspect import inspect
import re import re
import uuid import uuid
import warnings
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping
@ -30,9 +29,7 @@ from django.utils.ipv6 import clean_ipv6_address
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from pytz.exceptions import InvalidTimeError from pytz.exceptions import InvalidTimeError
from rest_framework import ( from rest_framework import ISO_8601
ISO_8601, RemovedInDRF313Warning, RemovedInDRF314Warning
)
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, json, representation from rest_framework.utils import html, humanize_datetime, json, representation
@ -265,16 +262,6 @@ class CreateOnlyDefault:
if is_update: if is_update:
raise SkipField() raise SkipField()
if callable(self.default): if callable(self.default):
if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.13. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
RemovedInDRF313Warning, stacklevel=2
)
self.default.set_context(self)
if getattr(self.default, 'requires_context', False): if getattr(self.default, 'requires_context', False):
return self.default(serializer_field) return self.default(serializer_field)
else: else:
@ -504,16 +491,6 @@ class Field:
# No default, or this is a partial update. # No default, or this is a partial update.
raise SkipField() raise SkipField()
if callable(self.default): if callable(self.default):
if hasattr(self.default, 'set_context'):
warnings.warn(
"Method `set_context` on defaults is deprecated and will "
"no longer be called starting with 3.13. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
RemovedInDRF313Warning, stacklevel=2
)
self.default.set_context(self)
if getattr(self.default, 'requires_context', False): if getattr(self.default, 'requires_context', False):
return self.default(self) return self.default(self)
else: else:
@ -578,16 +555,6 @@ class Field:
""" """
errors = [] errors = []
for validator in self.validators: for validator in self.validators:
if hasattr(validator, 'set_context'):
warnings.warn(
"Method `set_context` on validators is deprecated and will "
"no longer be called starting with 3.13. Instead set "
"`requires_context = True` on the class, and accept the "
"context as an additional argument.",
RemovedInDRF313Warning, stacklevel=2
)
validator.set_context(self)
try: try:
if getattr(validator, 'requires_context', False): if getattr(validator, 'requires_context', False):
validator(value, self) validator(value, self)
@ -744,23 +711,6 @@ class BooleanField(Field):
return bool(value) return bool(value)
class NullBooleanField(BooleanField):
initial = None
def __init__(self, **kwargs):
warnings.warn(
"The `NullBooleanField` is deprecated and will be removed starting "
"with 3.14. Instead use the `BooleanField` field and set "
"`allow_null=True` which does the same thing.",
RemovedInDRF314Warning, stacklevel=2
)
assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.'
kwargs['allow_null'] = True
super().__init__(**kwargs)
# String types... # String types...
class CharField(Field): class CharField(Field):

View File

@ -36,7 +36,6 @@ class SimpleMetadata(BaseMetadata):
label_lookup = ClassLookupDict({ label_lookup = ClassLookupDict({
serializers.Field: 'field', serializers.Field: 'field',
serializers.BooleanField: 'boolean', serializers.BooleanField: 'boolean',
serializers.NullBooleanField: 'boolean',
serializers.CharField: 'string', serializers.CharField: 'string',
serializers.UUIDField: 'string', serializers.UUIDField: 'string',
serializers.URLField: 'url', serializers.URLField: 'url',

View File

@ -4,7 +4,7 @@ incoming request. Typically this will be based on the request's Accept header.
""" """
from django.http import Http404 from django.http import Http404
from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework import exceptions
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import ( from rest_framework.utils.mediatypes import (
_MediaType, media_type_matches, order_by_precedence _MediaType, media_type_matches, order_by_precedence
@ -64,9 +64,11 @@ class DefaultContentNegotiation(BaseContentNegotiation):
# Accepted media type is 'application/json' # Accepted media type is 'application/json'
full_media_type = ';'.join( full_media_type = ';'.join(
(renderer.media_type,) + (renderer.media_type,) +
tuple('{}={}'.format( tuple(
key, value.decode(HTTP_HEADER_ENCODING)) '{}={}'.format(key, value)
for key, value in media_type_wrapper.params.items())) for key, value in media_type_wrapper.params.items()
)
)
return renderer, full_media_type return renderer, full_media_type
else: else:
# Eg client requests 'application/json; indent=8' # Eg client requests 'application/json; indent=8'

View File

@ -5,7 +5,6 @@ They give us a generic way of being able to handle various media types
on the request, such as form content or json encoded data. on the request, such as form content or json encoded data.
""" """
import codecs import codecs
from urllib import parse
from django.conf import settings from django.conf import settings
from django.core.files.uploadhandler import StopFutureHandlers from django.core.files.uploadhandler import StopFutureHandlers
@ -13,10 +12,10 @@ from django.http import QueryDict
from django.http.multipartparser import ChunkIter from django.http.multipartparser import ChunkIter
from django.http.multipartparser import \ from django.http.multipartparser import \
MultiPartParser as DjangoMultiPartParser MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header from django.http.multipartparser import MultiPartParserError
from django.utils.encoding import force_str
from rest_framework import renderers from rest_framework import renderers
from rest_framework.compat import parse_header_parameters
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import json from rest_framework.utils import json
@ -201,23 +200,10 @@ class FileUploadParser(BaseParser):
try: try:
meta = parser_context['request'].META meta = parser_context['request'].META
disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode()) disposition, params = parse_header_parameters(meta['HTTP_CONTENT_DISPOSITION'])
filename_parm = disposition[1] if 'filename*' in params:
if 'filename*' in filename_parm: return params['filename*']
return self.get_encoded_filename(filename_parm) else:
return force_str(filename_parm['filename']) return params['filename']
except (AttributeError, KeyError, ValueError): except (AttributeError, KeyError, ValueError):
pass pass
def get_encoded_filename(self, filename_parm):
"""
Handle encoded filenames per RFC6266. See also:
https://tools.ietf.org/html/rfc2231#section-4
"""
encoded_filename = force_str(filename_parm['filename*'])
try:
charset, lang, filename = encoded_filename.split('\'', 2)
filename = parse.unquote(filename)
except (ValueError, LookupError):
filename = force_str(filename_parm['filename'])
return filename

View File

@ -78,8 +78,11 @@ class OR:
def has_object_permission(self, request, view, obj): def has_object_permission(self, request, view, obj):
return ( return (
self.op1.has_object_permission(request, view, obj) or self.op1.has_permission(request, view)
self.op2.has_object_permission(request, view, obj) and self.op1.has_object_permission(request, view, obj)
) or (
self.op2.has_permission(request, view)
and self.op2.has_object_permission(request, view, obj)
) )

View File

@ -14,7 +14,6 @@ from django import forms
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import Page from django.core.paginator import Page
from django.http.multipartparser import parse_header
from django.template import engines, loader from django.template import engines, loader
from django.urls import NoReverseMatch from django.urls import NoReverseMatch
from django.utils.html import mark_safe from django.utils.html import mark_safe
@ -22,7 +21,7 @@ from django.utils.html import mark_safe
from rest_framework import VERSION, exceptions, serializers, status from rest_framework import VERSION, exceptions, serializers, status
from rest_framework.compat import ( from rest_framework.compat import (
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema, INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, coreschema,
pygments_css, yaml parse_header_parameters, pygments_css, yaml
) )
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.request import is_form_media_type, override_method from rest_framework.request import is_form_media_type, override_method
@ -72,7 +71,7 @@ class JSONRenderer(BaseRenderer):
# If the media type looks like 'application/json; indent=4', # If the media type looks like 'application/json; indent=4',
# then pretty print the result. # then pretty print the result.
# Note that we coerce `indent=0` into `indent=None`. # Note that we coerce `indent=0` into `indent=None`.
base_media_type, params = parse_header(accepted_media_type.encode('ascii')) base_media_type, params = parse_header_parameters(accepted_media_type)
try: try:
return zero_as_none(max(min(int(params['indent']), 8), 0)) return zero_as_none(max(min(int(params['indent']), 8), 0))
except (KeyError, ValueError, TypeError): except (KeyError, ValueError, TypeError):

View File

@ -14,11 +14,11 @@ from contextlib import contextmanager
from django.conf import settings from django.conf import settings
from django.http import HttpRequest, QueryDict from django.http import HttpRequest, QueryDict
from django.http.multipartparser import parse_header
from django.http.request import RawPostDataException from django.http.request import RawPostDataException
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework import exceptions
from rest_framework.compat import parse_header_parameters
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -26,7 +26,7 @@ def is_form_media_type(media_type):
""" """
Return True if the media type is a valid form media type. Return True if the media type is a valid form media type.
""" """
base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING)) base_media_type, params = parse_header_parameters(media_type)
return (base_media_type == 'application/x-www-form-urlencoded' or return (base_media_type == 'application/x-www-form-urlencoded' or
base_media_type == 'multipart/form-data') base_media_type == 'multipart/form-data')

View File

@ -198,7 +198,11 @@ class SchemaGenerator(BaseSchemaGenerator):
if is_custom_action(action): if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/" # Custom action, eg "/users/{pk}/activate/", "/users/active/"
if len(view.action_map) > 1: mapped_methods = {
# Don't count head mapping, e.g. not part of the schema
method for method in view.action_map if method != 'head'
}
if len(mapped_methods) > 1:
action = self.default_mapping[method.lower()] action = self.default_mapping[method.lower()]
if action in self.coerce_method_names: if action in self.coerce_method_names:
action = self.coerce_method_names[action] action = self.coerce_method_names[action]

View File

@ -13,7 +13,7 @@ from django.db import models
from django.utils.encoding import force_str from django.utils.encoding import force_str
from rest_framework import ( from rest_framework import (
RemovedInDRF314Warning, exceptions, renderers, serializers RemovedInDRF315Warning, exceptions, renderers, serializers
) )
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty from rest_framework.fields import _UnvalidatedField, empty
@ -713,106 +713,10 @@ class AutoSchema(ViewInspector):
return [path.split('/')[0].replace('_', '-')] return [path.split('/')[0].replace('_', '-')]
def _get_path_parameters(self, path, method):
warnings.warn(
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_path_parameters(path, method)
def _get_filter_parameters(self, path, method):
warnings.warn(
"Method `_get_filter_parameters()` has been renamed to `get_filter_parameters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_filter_parameters(path, method)
def _get_responses(self, path, method):
warnings.warn(
"Method `_get_responses()` has been renamed to `get_responses()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_responses(path, method)
def _get_request_body(self, path, method):
warnings.warn(
"Method `_get_request_body()` has been renamed to `get_request_body()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_request_body(path, method)
def _get_serializer(self, path, method):
warnings.warn(
"Method `_get_serializer()` has been renamed to `get_serializer()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_serializer(path, method)
def _get_paginator(self):
warnings.warn(
"Method `_get_paginator()` has been renamed to `get_paginator()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_paginator()
def _map_field_validators(self, field, schema):
warnings.warn(
"Method `_map_field_validators()` has been renamed to `map_field_validators()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_field_validators(field, schema)
def _map_serializer(self, serializer):
warnings.warn(
"Method `_map_serializer()` has been renamed to `map_serializer()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_serializer(serializer)
def _map_field(self, field):
warnings.warn(
"Method `_map_field()` has been renamed to `map_field()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_field(field)
def _map_choicefield(self, field):
warnings.warn(
"Method `_map_choicefield()` has been renamed to `map_choicefield()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.map_choicefield(field)
def _get_pagination_parameters(self, path, method):
warnings.warn(
"Method `_get_pagination_parameters()` has been renamed to `get_pagination_parameters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.get_pagination_parameters(path, method)
def _allows_filters(self, path, method):
warnings.warn(
"Method `_allows_filters()` has been renamed to `allows_filters()`. "
"The old name will be removed in DRF v3.14.",
RemovedInDRF314Warning, stacklevel=2
)
return self.allows_filters(path, method)
def _get_reference(self, serializer): def _get_reference(self, serializer):
warnings.warn( warnings.warn(
"Method `_get_reference()` has been renamed to `get_reference()`. " "Method `_get_reference()` has been renamed to `get_reference()`. "
"The old name will be removed in DRF v3.14.", "The old name will be removed in DRF v3.15.",
RemovedInDRF314Warning, stacklevel=2 RemovedInDRF315Warning, stacklevel=2
) )
return self.get_reference(serializer) return self.get_reference(serializer)

View File

@ -52,7 +52,7 @@ from rest_framework.fields import ( # NOQA # isort:skip
BooleanField, CharField, ChoiceField, DateField, DateTimeField, DecimalField, BooleanField, CharField, ChoiceField, DateField, DateTimeField, DecimalField,
DictField, DurationField, EmailField, Field, FileField, FilePathField, FloatField, DictField, DurationField, EmailField, Field, FileField, FilePathField, FloatField,
HiddenField, HStoreField, IPAddressField, ImageField, IntegerField, JSONField, HiddenField, HStoreField, IPAddressField, ImageField, IntegerField, JSONField,
ListField, ModelField, MultipleChoiceField, NullBooleanField, ReadOnlyField, ListField, ModelField, MultipleChoiceField, ReadOnlyField,
RegexField, SerializerMethodField, SlugField, TimeField, URLField, UUIDField, RegexField, SerializerMethodField, SlugField, TimeField, URLField, UUIDField,
) )
from rest_framework.relations import ( # NOQA # isort:skip from rest_framework.relations import ( # NOQA # isort:skip
@ -216,7 +216,7 @@ class BaseSerializer(Field):
return self.instance return self.instance
def is_valid(self, raise_exception=False): def is_valid(self, *, raise_exception=False):
assert hasattr(self, 'initial_data'), ( assert hasattr(self, 'initial_data'), (
'Cannot call `.is_valid()` as no `data=` keyword argument was ' 'Cannot call `.is_valid()` as no `data=` keyword argument was '
'passed when instantiating the serializer instance.' 'passed when instantiating the serializer instance.'
@ -735,7 +735,7 @@ class ListSerializer(BaseSerializer):
return self.instance return self.instance
def is_valid(self, raise_exception=False): def is_valid(self, *, raise_exception=False):
# This implementation is the same as the default, # This implementation is the same as the default,
# except that we use lists, rather than dicts, as the empty case. # except that we use lists, rather than dicts, as the empty case.
assert hasattr(self, 'initial_data'), ( assert hasattr(self, 'initial_data'), (

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -277,7 +277,7 @@ class APIClient(APIRequestFactory, DjangoClient):
""" """
self.handler._force_user = user self.handler._force_user = user
self.handler._force_token = token self.handler._force_token = token
if user is None: if user is None and token is None:
self.logout() # Also clear any possible session info if required self.logout() # Also clear any possible session info if required
def request(self, **kwargs): def request(self, **kwargs):

View File

@ -3,9 +3,7 @@ Handling of media types, as found in HTTP Content-Type and Accept headers.
See https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7 See https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html#sec3.7
""" """
from django.http.multipartparser import parse_header from rest_framework.compat import parse_header_parameters
from rest_framework import HTTP_HEADER_ENCODING
def media_type_matches(lhs, rhs): def media_type_matches(lhs, rhs):
@ -46,7 +44,7 @@ def order_by_precedence(media_type_lst):
class _MediaType: class _MediaType:
def __init__(self, media_type_str): def __init__(self, media_type_str):
self.orig = '' if (media_type_str is None) else media_type_str self.orig = '' if (media_type_str is None) else media_type_str
self.full_type, self.params = parse_header(self.orig.encode(HTTP_HEADER_ENCODING)) self.full_type, self.params = parse_header_parameters(self.orig)
self.main_type, sep, self.sub_type = self.full_type.partition('/') self.main_type, sep, self.sub_type = self.full_type.partition('/')
def match(self, other): def match(self, other):
@ -79,5 +77,5 @@ class _MediaType:
def __str__(self): def __str__(self):
ret = "%s/%s" % (self.main_type, self.sub_type) ret = "%s/%s" % (self.main_type, self.sub_type)
for key, val in self.params.items(): for key, val in self.params.items():
ret += "; %s=%s" % (key, val.decode('ascii')) ret += "; %s=%s" % (key, val)
return ret return ret

View File

@ -198,6 +198,10 @@ class ViewSetMixin:
for action in actions: for action in actions:
try: try:
url_name = '%s-%s' % (self.basename, action.url_name) url_name = '%s-%s' % (self.basename, action.url_name)
namespace = self.request.resolver_match.namespace
if namespace:
url_name = '%s:%s' % (namespace, url_name)
url = reverse(url_name, self.args, self.kwargs, request=self.request) url = reverse(url_name, self.args, self.kwargs, request=self.request)
view = self.__class__(**action.kwargs) view = self.__class__(**action.kwargs)
action_urls[view.get_view_name()] = url action_urls[view.get_view_name()] = url

View File

@ -1,11 +1,11 @@
[metadata] [metadata]
license_file = LICENSE.md license_files = LICENSE.md
[tool:pytest] [tool:pytest]
addopts=--tb=short --strict-markers -ra addopts=--tb=short --strict-markers -ra
[flake8] [flake8]
ignore = E501,W504 ignore = E501,W503,W504
banned-modules = json = use from rest_framework.utils import json! banned-modules = json = use from rest_framework.utils import json!
[isort] [isort]

View File

@ -219,8 +219,8 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with CSRF token succeeds. Ensure POSTing form over session authentication with CSRF token succeeds.
Regression test for #6088 Regression test for #6088
""" """
# Remove this shim when dropping support for Django 2.2. # Remove this shim when dropping support for Django 3.0.
if django.VERSION < (3, 0): if django.VERSION < (3, 1):
from django.middleware.csrf import _get_new_csrf_token from django.middleware.csrf import _get_new_csrf_token
else: else:
from django.middleware.csrf import ( from django.middleware.csrf import (

View File

@ -754,6 +754,67 @@ class TestSchemaGeneratorWithManyToMany(TestCase):
assert schema == expected assert schema == expected
@unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class TestSchemaGeneratorActionKeysViewSets(TestCase):
def test_action_not_coerced_for_get_and_head(self):
"""
Ensure that action name is preserved when action map contains "head".
"""
class CustomViewSet(GenericViewSet):
serializer_class = EmptySerializer
@action(methods=['get', 'head'], detail=True)
def custom_read(self, request, pk):
raise NotImplementedError
@action(methods=['put', 'patch'], detail=True)
def custom_mixed_update(self, request, pk):
raise NotImplementedError
self.router = DefaultRouter()
self.router.register('example', CustomViewSet, basename='example')
self.patterns = [
path('', include(self.router.urls))
]
generator = SchemaGenerator(title='Example API', patterns=self.patterns)
schema = generator.get_schema()
expected = coreapi.Document(
url='',
title='Example API',
content={
'example': {
'custom_read': coreapi.Link(
url='/example/{id}/custom_read/',
action='get',
fields=[
coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
]
),
'custom_mixed_update': {
'update': coreapi.Link(
url='/example/{id}/custom_mixed_update/',
action='put',
fields=[
coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
]
),
'partial_update': coreapi.Link(
url='/example/{id}/custom_mixed_update/',
action='patch',
fields=[
coreapi.Field('id', required=True, location='path', schema=coreschema.String()),
]
)
}
}
}
)
assert schema == expected
@unittest.skipUnless(coreapi, 'coreapi is not installed') @unittest.skipUnless(coreapi, 'coreapi is not installed')
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
class Test4605Regression(TestCase): class Test4605Regression(TestCase):

View File

@ -52,9 +52,9 @@ class GenerateSchemaTests(TestCase):
@pytest.mark.skipif(yaml is None, reason='PyYAML is required.') @pytest.mark.skipif(yaml is None, reason='PyYAML is required.')
def test_renders_default_schema_with_custom_title_url_and_description(self): def test_renders_default_schema_with_custom_title_url_and_description(self):
call_command('generateschema', call_command('generateschema',
'--title=SampleAPI', '--title=ExampleAPI',
'--url=http://api.sample.com', '--url=http://api.example.com',
'--description=Sample description', '--description=Example description',
stdout=self.out) stdout=self.out)
# Check valid YAML was output. # Check valid YAML was output.
schema = yaml.safe_load(self.out.getvalue()) schema = yaml.safe_load(self.out.getvalue())
@ -94,8 +94,8 @@ class GenerateSchemaTests(TestCase):
@override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}) @override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'})
def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self): def test_coreapi_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info: expected_out = """info:
description: Sample description description: Example description
title: SampleAPI title: ExampleAPI
version: '' version: ''
openapi: 3.0.0 openapi: 3.0.0
paths: paths:
@ -103,12 +103,12 @@ class GenerateSchemaTests(TestCase):
get: get:
operationId: list operationId: list
servers: servers:
- url: http://api.sample.com/ - url: http://api.example.com/
""" """
call_command('generateschema', call_command('generateschema',
'--title=SampleAPI', '--title=ExampleAPI',
'--url=http://api.sample.com', '--url=http://api.example.com',
'--description=Sample description', '--description=Example description',
stdout=self.out) stdout=self.out)
self.assertIn(formatting.dedent(expected_out), self.out.getvalue()) self.assertIn(formatting.dedent(expected_out), self.out.getvalue())

View File

@ -566,7 +566,7 @@ class TestCreateOnlyDefault:
def test_create_only_default_callable_sets_context(self): def test_create_only_default_callable_sets_context(self):
""" """
CreateOnlyDefault instances with a callable default should set_context CreateOnlyDefault instances with a callable default should set context
on the callable if possible on the callable if possible
""" """
class TestCallableDefault: class TestCallableDefault:
@ -679,9 +679,9 @@ class TestBooleanField(FieldValues):
assert exc_info.value.detail == expected assert exc_info.value.detail == expected
class TestNullBooleanField(TestBooleanField): class TestNullableBooleanField(TestBooleanField):
""" """
Valid and invalid values for `NullBooleanField`. Valid and invalid values for `BooleanField` when `allow_null=True`.
""" """
valid_inputs = { valid_inputs = {
'true': True, 'true': True,
@ -706,16 +706,6 @@ class TestNullBooleanField(TestBooleanField):
field = serializers.BooleanField(allow_null=True) field = serializers.BooleanField(allow_null=True)
class TestNullableBooleanField(TestNullBooleanField):
"""
Valid and invalid values for `BooleanField` when `allow_null=True`.
"""
@property
def field(self):
return serializers.BooleanField(allow_null=True)
# String types... # String types...
class TestCharField(FieldValues): class TestCharField(FieldValues):

View File

@ -635,7 +635,7 @@ class PermissionsCompositionTests(TestCase):
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is True assert hasperm is True
assert mock_deny.call_count == 1 assert mock_deny.call_count == 0
assert mock_allow.call_count == 1 assert mock_allow.call_count == 1
def test_and_lazyness(self): def test_and_lazyness(self):
@ -677,3 +677,16 @@ class PermissionsCompositionTests(TestCase):
assert hasperm is False assert hasperm is False
assert mock_deny.call_count == 1 assert mock_deny.call_count == 1
mock_allow.assert_not_called() mock_allow.assert_not_called()
def test_unimplemented_has_object_permission(self):
"test for issue 6402 https://github.com/encode/django-rest-framework/issues/6402"
request = factory.get('/1', format='json')
request.user = AnonymousUser()
class IsAuthenticatedUserOwner(permissions.IsAuthenticated):
def has_object_permission(self, request, view, obj):
return True
composed_perm = (IsAuthenticatedUserOwner | permissions.IsAdminUser)
hasperm = composed_perm().has_object_permission(request, None, None)
assert hasperm is False

View File

@ -10,6 +10,7 @@ from django.test import TestCase, override_settings
from django.urls import path from django.urls import path
from rest_framework import fields, serializers from rest_framework import fields, serializers
from rest_framework.authtoken.models import Token
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import ( from rest_framework.test import (
@ -19,10 +20,12 @@ from rest_framework.test import (
@api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) @api_view(['GET', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS'])
def view(request): def view(request):
return Response({ data = {'auth': request.META.get('HTTP_AUTHORIZATION', b'')}
'auth': request.META.get('HTTP_AUTHORIZATION', b''), if request.user:
'user': request.user.username data['user'] = request.user.username
}) if request.auth:
data['token'] = request.auth.key
return Response(data)
@api_view(['GET', 'POST']) @api_view(['GET', 'POST'])
@ -78,14 +81,46 @@ class TestAPITestClient(TestCase):
response = self.client.get('/view/') response = self.client.get('/view/')
assert response.data['auth'] == 'example' assert response.data['auth'] == 'example'
def test_force_authenticate(self): def test_force_authenticate_with_user(self):
""" """
Setting `.force_authenticate()` forcibly authenticates each request. Setting `.force_authenticate()` with a user forcibly authenticates each
request with that user.
""" """
user = User.objects.create_user('example', 'example@example.com') user = User.objects.create_user('example', 'example@example.com')
self.client.force_authenticate(user)
self.client.force_authenticate(user=user)
response = self.client.get('/view/') response = self.client.get('/view/')
assert response.data['user'] == 'example' assert response.data['user'] == 'example'
assert 'token' not in response.data
def test_force_authenticate_with_token(self):
"""
Setting `.force_authenticate()` with a token forcibly authenticates each
request with that token.
"""
user = User.objects.create_user('example', 'example@example.com')
token = Token.objects.create(key='xyz', user=user)
self.client.force_authenticate(token=token)
response = self.client.get('/view/')
assert response.data['token'] == 'xyz'
assert 'user' not in response.data
def test_force_authenticate_with_user_and_token(self):
"""
Setting `.force_authenticate()` with a user and token forcibly
authenticates each request with that user and token.
"""
user = User.objects.create_user('example', 'example@example.com')
token = Token.objects.create(key='xyz', user=user)
self.client.force_authenticate(user=user, token=token)
response = self.client.get('/view/')
assert response.data['user'] == 'example'
assert response.data['token'] == 'xyz'
def test_force_authenticate_with_sessions(self): def test_force_authenticate_with_sessions(self):
""" """
@ -102,8 +137,9 @@ class TestAPITestClient(TestCase):
response = self.client.get('/session-view/') response = self.client.get('/session-view/')
assert response.data['active_session'] is True assert response.data['active_session'] is True
# Force authenticating as `None` should also logout the user session. # Force authenticating with `None` user and token should also logout
self.client.force_authenticate(None) # the user session.
self.client.force_authenticate(user=None, token=None)
response = self.client.get('/session-view/') response = self.client.get('/session-view/')
assert response.data['active_session'] is False assert response.data['active_session'] is False

View File

@ -1,6 +1,6 @@
[tox] [tox]
envlist = envlist =
{py36,py37,py38,py39}-django22, {py36,py37,py38,py39}-django30,
{py36,py37,py38,py39}-django31, {py36,py37,py38,py39}-django31,
{py36,py37,py38,py39,py310}-django32, {py36,py37,py38,py39,py310}-django32,
{py38,py39,py310}-{django40,django41,djangomain}, {py38,py39,py310}-{django40,django41,djangomain},
@ -8,7 +8,7 @@ envlist =
[travis:env] [travis:env]
DJANGO = DJANGO =
2.2: django22 3.0: django30
3.1: django31 3.1: django31
3.2: django32 3.2: django32
4.0: django40 4.0: django40
@ -22,7 +22,7 @@ setenv =
PYTHONDONTWRITEBYTECODE=1 PYTHONDONTWRITEBYTECODE=1
PYTHONWARNINGS=once PYTHONWARNINGS=once
deps = deps =
django22: Django>=2.2,<3.0 django30: Django>=3.0,<3.1
django31: Django>=3.1,<3.2 django31: Django>=3.1,<3.2
django32: Django>=3.2,<4.0 django32: Django>=3.2,<4.0
django40: Django>=4.0,<4.1 django40: Django>=4.0,<4.1