Merge branch 'master' into contextvars-for-request-context

This commit is contained in:
Daler 2023-06-21 14:48:41 +05:00 committed by GitHub
commit 6a7df3bb09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 1204 additions and 262 deletions

View File

@ -8,17 +8,17 @@ on:
jobs:
pre-commit:
runs-on: ubuntu-20.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: actions/setup-python@v2
- uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: "3.10"
- uses: pre-commit/action@v2.0.0
- uses: pre-commit/action@v3.0.0
with:
token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -55,7 +55,7 @@ There is a live example API for testing purposes, [available here][sandbox].
# Requirements
* Python 3.6+
* Django 4.1, 4.0, 3.2, 3.1, 3.0
* Django 4.2, 4.1, 4.0, 3.2, 3.1, 3.0
We **highly recommend** and only officially support the latest patch release of
each Python and Django series.

View File

@ -165,7 +165,7 @@ This permission is suitable if you want your API to only be accessible to a subs
## IsAuthenticatedOrReadOnly
The `IsAuthenticatedOrReadOnly` will allow authenticated users to perform any request. Requests for unauthorised users will only be permitted if the request method is one of the "safe" methods; `GET`, `HEAD` or `OPTIONS`.
The `IsAuthenticatedOrReadOnly` will allow authenticated users to perform any request. Requests for unauthenticated users will only be permitted if the request method is one of the "safe" methods; `GET`, `HEAD` or `OPTIONS`.
This permission is suitable if you want to your API to allow read permissions to anonymous users, and only allow write permissions to authenticated users.

View File

@ -54,5 +54,5 @@ As with the `reverse` function, you should **include the request as a keyword ar
api_root = reverse_lazy('api-root', request=request)
[cite]: https://www.ics.uci.edu/~fielding/pubs/dissertation/rest_arch_style.htm#sec_5_1_5
[reverse]: https://docs.djangoproject.com/en/stable/topics/http/urls/#reverse
[reverse-lazy]: https://docs.djangoproject.com/en/stable/topics/http/urls/#reverse-lazy
[reverse]: https://docs.djangoproject.com/en/stable/ref/urlresolvers/#reverse
[reverse-lazy]: https://docs.djangoproject.com/en/stable/ref/urlresolvers/#reverse-lazy

View File

@ -53,7 +53,7 @@ If we open up the Django shell using `manage.py shell` we can now
The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field.
Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below.
Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. REST framework validators, like their Django counterparts, implement the `__eq__` method, allowing you to compare instances for equality.
---
@ -295,13 +295,14 @@ To write a class-based validator, use the `__call__` method. Class-based validat
In some advanced cases you might want a validator to be passed the serializer
field it is being used with as additional context. You can do so by setting
a `requires_context = True` attribute on the validator. The `__call__` method
a `requires_context = True` attribute on the validator class. The `__call__` method
will then be called with the `serializer_field`
or `serializer` as an additional argument.
requires_context = True
class MultipleOf:
requires_context = True
def __call__(self, value, serializer_field):
...
def __call__(self, value, serializer_field):
...
[cite]: https://docs.djangoproject.com/en/stable/ref/validators/

View File

@ -143,6 +143,16 @@ class PublisherSearchView(generics.ListAPIView):
---
## Deprecations
### `serializers.NullBooleanField`
`serializers.NullBooleanField` is now pending deprecation, and will be removed in 3.14.
Instead use `serializers.BooleanField` field and set `allow_null=True` which does the same thing.
---
## Funding
REST framework is a *collaboratively funded project*. If you use

View File

@ -60,3 +60,13 @@ See Pull Request [#7522](https://github.com/encode/django-rest-framework/pull/75
## Minor fixes and improvements
There are a number of minor fixes and improvements in this release. See the [release notes](release-notes.md) page for a complete listing.
---
## Deprecations
### `serializers.NullBooleanField`
`serializers.NullBooleanField` was moved to pending deprecation in 3.12, and deprecated in 3.13. It has now been removed from the core framework.
Instead use `serializers.BooleanField` field and set `allow_null=True` which does the same thing.

View File

@ -157,6 +157,7 @@ Date: 28th September 2020
* Fix `PrimaryKeyRelatedField` and `HyperlinkedRelatedField` when source field is actually a property. [#7142]
* `Token.generate_key` is now a class method. [#7502]
* `@action` warns if method is wrapped in a decorator that does not preserve information using `@functools.wraps`. [#7098]
* Deprecate `serializers.NullBooleanField` in favour of `serializers.BooleanField` with `allow_null=True` [#7122]
---

View File

@ -86,7 +86,7 @@ continued development by **[signing up for a paid plan][funding]**.
REST framework requires the following:
* Python (3.6, 3.7, 3.8, 3.9, 3.10, 3.11)
* Django (2.2, 3.0, 3.1, 3.2, 4.0, 4.1)
* Django (3.0, 3.1, 3.2, 4.0, 4.1)
We **highly recommend** and only officially support the latest patch release of
each Python and Django series.

View File

@ -1,8 +1,8 @@
# Wheel for PyPI installs.
wheel>=0.35.1,<0.36
wheel>=0.36.2,<0.40.0
# Twine for secured PyPI uploads.
twine>=3.2.0,<3.3
twine>=3.4.2,<4.0.2
# Transifex client for managing translation resources.
transifex-client

View File

@ -13,7 +13,7 @@ __title__ = 'Django REST framework'
__version__ = '3.14.0'
__author__ = 'Tom Christie'
__license__ = 'BSD 3-Clause'
__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd'
__copyright__ = 'Copyright 2011-2023 Encode OSS Ltd'
# Version synonym
VERSION = __version__
@ -31,3 +31,7 @@ if django.VERSION < (3, 2):
class RemovedInDRF315Warning(DeprecationWarning):
pass
class RemovedInDRF317Warning(PendingDeprecationWarning):
pass

View File

@ -4,6 +4,7 @@ from django.contrib.admin.views.main import ChangeList
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from rest_framework.authtoken.models import Token, TokenProxy
@ -23,6 +24,8 @@ class TokenChangeList(ChangeList):
class TokenAdmin(admin.ModelAdmin):
list_display = ('key', 'user', 'created')
fields = ('user',)
search_fields = ('user__username',)
search_help_text = _('Username')
ordering = ('-created',)
actions = None # Actions not compatible with mapped IDs.
autocomplete_fields = ("user",)

View File

@ -4,9 +4,9 @@ import datetime
import decimal
import functools
import inspect
import logging
import re
import uuid
from collections import OrderedDict
from collections.abc import Mapping
from django.conf import settings
@ -17,6 +17,7 @@ from django.core.validators import (
MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
URLValidator, ip_address_validators
)
from django.db.models import IntegerChoices, TextChoices
from django.forms import FilePathField as DjangoFilePathField
from django.forms import ImageField as DjangoImageField
from django.utils import timezone
@ -28,15 +29,22 @@ from django.utils.encoding import is_protected_type, smart_str
from django.utils.formats import localize_input, sanitize_separators
from django.utils.ipv6 import clean_ipv6_address
from django.utils.translation import gettext_lazy as _
from pytz.exceptions import InvalidTimeError
try:
import pytz
except ImportError:
pytz = None
from rest_framework import ISO_8601
from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, json, representation
from rest_framework.utils.formatting import lazy_format
from rest_framework.utils.timezone import valid_datetime
from rest_framework.validators import ProhibitSurrogateCharactersValidator
logger = logging.getLogger("rest_framework.fields")
class empty:
"""
@ -109,27 +117,6 @@ def get_attribute(instance, attrs):
return instance
def set_value(dictionary, keys, value):
"""
Similar to Python's built in `dictionary[key] = value`,
but takes a list of nested keys instead of a single key.
set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
"""
if not keys:
dictionary.update(value)
return
for key in keys[:-1]:
if key not in dictionary:
dictionary[key] = {}
dictionary = dictionary[key]
dictionary[keys[-1]] = value
def to_choices_dict(choices):
"""
Convert choices into key/value dicts.
@ -142,7 +129,7 @@ def to_choices_dict(choices):
# choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
# choices = [('Category', ((1, 'First'), (2, 'Second'))), (3, 'Third')]
ret = OrderedDict()
ret = {}
for choice in choices:
if not isinstance(choice, (list, tuple)):
# single choice
@ -165,7 +152,7 @@ def flatten_choices_dict(choices):
flatten_choices_dict({1: '1st', 2: '2nd'}) -> {1: '1st', 2: '2nd'}
flatten_choices_dict({'Group': {1: '1st', 2: '2nd'}}) -> {1: '1st', 2: '2nd'}
"""
ret = OrderedDict()
ret = {}
for key, value in choices.items():
if isinstance(value, dict):
# grouped choices (category, sub choices)
@ -677,22 +664,27 @@ class BooleanField(Field):
default_empty_html = False
initial = False
TRUE_VALUES = {
't', 'T',
'y', 'Y', 'yes', 'Yes', 'YES',
'true', 'True', 'TRUE',
'on', 'On', 'ON',
'1', 1,
True
't',
'y',
'yes',
'true',
'on',
'1',
1,
True,
}
FALSE_VALUES = {
'f', 'F',
'n', 'N', 'no', 'No', 'NO',
'false', 'False', 'FALSE',
'off', 'Off', 'OFF',
'0', 0, 0.0,
False
'f',
'n',
'no',
'false',
'off',
'0',
0,
0.0,
False,
}
NULL_VALUES = {'null', 'Null', 'NULL', '', None}
NULL_VALUES = {'null', '', None}
def __init__(self, **kwargs):
if kwargs.get('allow_null', False):
@ -700,22 +692,28 @@ class BooleanField(Field):
self.initial = None
super().__init__(**kwargs)
@staticmethod
def _lower_if_str(value):
if isinstance(value, str):
return value.lower()
return value
def to_internal_value(self, data):
with contextlib.suppress(TypeError):
if data in self.TRUE_VALUES:
if self._lower_if_str(data) in self.TRUE_VALUES:
return True
elif data in self.FALSE_VALUES:
elif self._lower_if_str(data) in self.FALSE_VALUES:
return False
elif data in self.NULL_VALUES and self.allow_null:
elif self._lower_if_str(data) in self.NULL_VALUES and self.allow_null:
return None
self.fail('invalid', input=data)
self.fail("invalid", input=data)
def to_representation(self, value):
if value in self.TRUE_VALUES:
if self._lower_if_str(value) in self.TRUE_VALUES:
return True
elif value in self.FALSE_VALUES:
elif self._lower_if_str(value) in self.FALSE_VALUES:
return False
if value in self.NULL_VALUES and self.allow_null:
if self._lower_if_str(value) in self.NULL_VALUES and self.allow_null:
return None
return bool(value)
@ -989,6 +987,11 @@ class DecimalField(Field):
self.max_value = max_value
self.min_value = min_value
if self.max_value is not None and not isinstance(self.max_value, decimal.Decimal):
logger.warning("max_value in DecimalField should be Decimal type.")
if self.min_value is not None and not isinstance(self.min_value, decimal.Decimal):
logger.warning("min_value in DecimalField should be Decimal type.")
if self.max_digits is not None and self.decimal_places is not None:
self.max_whole_digits = self.max_digits - self.decimal_places
else:
@ -1154,9 +1157,16 @@ class DateTimeField(Field):
except OverflowError:
self.fail('overflow')
try:
return timezone.make_aware(value, field_timezone)
except InvalidTimeError:
self.fail('make_aware', timezone=field_timezone)
dt = timezone.make_aware(value, field_timezone)
# When the resulting datetime is a ZoneInfo instance, it won't necessarily
# throw given an invalid datetime, so we need to specifically check.
if not valid_datetime(dt):
self.fail('make_aware', timezone=field_timezone)
return dt
except Exception as e:
if pytz and isinstance(e, pytz.exceptions.InvalidTimeError):
self.fail('make_aware', timezone=field_timezone)
raise e
elif (field_timezone is None) and timezone.is_aware(value):
return timezone.make_naive(value, datetime.timezone.utc)
return value
@ -1392,6 +1402,10 @@ class ChoiceField(Field):
if data == '' and self.allow_blank:
return ''
if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \
str(data.value):
data = data.value
try:
return self.choice_strings_to_values[str(data)]
except KeyError:
@ -1400,6 +1414,11 @@ class ChoiceField(Field):
def to_representation(self, value):
if value in ('', None):
return value
if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \
str(value.value):
value = value.value
return self.choice_strings_to_values.get(str(value), value)
def iter_options(self):
@ -1423,7 +1442,8 @@ class ChoiceField(Field):
# Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out.
self.choice_strings_to_values = {
str(key): key for key in self.choices
str(key.value) if isinstance(key, (IntegerChoices, TextChoices))
and str(key) != str(key.value) else str(key): key for key in self.choices
}
choices = property(_get_choices, _set_choices)
@ -1643,7 +1663,7 @@ class ListField(Field):
def run_child_validation(self, data):
result = []
errors = OrderedDict()
errors = {}
for idx, item in enumerate(data):
try:
@ -1707,7 +1727,7 @@ class DictField(Field):
def run_child_validation(self, data):
result = {}
errors = OrderedDict()
errors = {}
for key, value in data.items():
key = str(key)

View File

@ -3,6 +3,7 @@ Provides generic filtering backends that can be used to filter the results
returned by list views.
"""
import operator
import warnings
from functools import reduce
from django.core.exceptions import ImproperlyConfigured
@ -12,6 +13,7 @@ from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema, distinct
from rest_framework.settings import api_settings
@ -29,6 +31,8 @@ class BaseFilterBackend:
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return []
@ -146,6 +150,8 @@ class SearchFilter(BaseFilterBackend):
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
@ -306,6 +312,8 @@ class OrderingFilter(BaseFilterBackend):
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(

View File

@ -6,8 +6,6 @@ some fairly ad-hoc information about the view.
Future implementations might use JSON schema or other definitions in order
to return this information in a more standardized way.
"""
from collections import OrderedDict
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.utils.encoding import force_str
@ -59,11 +57,12 @@ class SimpleMetadata(BaseMetadata):
})
def determine_metadata(self, request, view):
metadata = OrderedDict()
metadata['name'] = view.get_view_name()
metadata['description'] = view.get_view_description()
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes]
metadata['parses'] = [parser.media_type for parser in view.parser_classes]
metadata = {
"name": view.get_view_name(),
"description": view.get_view_description(),
"renders": [renderer.media_type for renderer in view.renderer_classes],
"parses": [parser.media_type for parser in view.parser_classes],
}
if hasattr(view, 'get_serializer'):
actions = self.determine_actions(request, view)
if actions:
@ -106,25 +105,27 @@ class SimpleMetadata(BaseMetadata):
# If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead.
serializer = serializer.child
return OrderedDict([
(field_name, self.get_field_info(field))
return {
field_name: self.get_field_info(field)
for field_name, field in serializer.fields.items()
if not isinstance(field, serializers.HiddenField)
])
}
def get_field_info(self, field):
"""
Given an instance of a serializer field, return a dictionary
of metadata about it.
"""
field_info = OrderedDict()
field_info['type'] = self.label_lookup[field]
field_info['required'] = getattr(field, 'required', False)
field_info = {
"type": self.label_lookup[field],
"required": getattr(field, "required", False),
}
attrs = [
'read_only', 'label', 'help_text',
'min_length', 'max_length',
'min_value', 'max_value'
'min_value', 'max_value',
'max_digits', 'decimal_places'
]
for attr in attrs:

View File

@ -4,16 +4,19 @@ be used for paginated responses.
"""
import contextlib
import warnings
from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple
from collections import namedtuple
from urllib import parse
from django.core.paginator import InvalidPage
from django.core.paginator import Paginator as DjangoPaginator
from django.db.models import Q
from django.template import loader
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema
from rest_framework.exceptions import NotFound
from rest_framework.response import Response
@ -151,6 +154,8 @@ class BasePagination:
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
return []
def get_schema_operation_parameters(self, view):
@ -224,12 +229,12 @@ class PageNumberPagination(BasePagination):
return page_number
def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.page.paginator.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
return Response({
'count': self.page.paginator.count,
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'results': data,
})
def get_paginated_response_schema(self, schema):
return {
@ -310,6 +315,8 @@ class PageNumberPagination(BasePagination):
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [
coreapi.Field(
@ -394,12 +401,12 @@ class LimitOffsetPagination(BasePagination):
return list(queryset[self.offset:self.offset + self.limit])
def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
return Response({
'count': self.count,
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'results': data
})
def get_paginated_response_schema(self, schema):
return {
@ -524,6 +531,8 @@ class LimitOffsetPagination(BasePagination):
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [
coreapi.Field(
@ -620,7 +629,7 @@ class CursorPagination(BasePagination):
queryset = queryset.order_by(*self.ordering)
# If we have a cursor with a fixed position then filter by that.
if current_position is not None:
if str(current_position) != 'None':
order = self.ordering[0]
is_reversed = order.startswith('-')
order_attr = order.lstrip('-')
@ -631,7 +640,12 @@ class CursorPagination(BasePagination):
else:
kwargs = {order_attr + '__gt': current_position}
queryset = queryset.filter(**kwargs)
filter_query = Q(**kwargs)
# If some records contain a null for the ordering field, don't lose them.
# When reverse ordering, nulls will come last and need to be included.
if (reverse and not is_reversed) or is_reversed:
filter_query |= Q(**{order_attr + '__isnull': True})
queryset = queryset.filter(filter_query)
# If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a
@ -704,7 +718,7 @@ class CursorPagination(BasePagination):
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
has_item_with_unique_position = position is not None
break
# The item in this position has the same position as the item
@ -757,7 +771,7 @@ class CursorPagination(BasePagination):
# The item in this position and the item following it
# have different positions. We can use this position as
# our marker.
has_item_with_unique_position = True
has_item_with_unique_position = position is not None
break
# The item in this position has the same position as the item
@ -795,6 +809,10 @@ class CursorPagination(BasePagination):
"""
Return a tuple of strings, that may be used in an `order_by` method.
"""
# The default case is to check for an `ordering` attribute
# on this pagination instance.
ordering = self.ordering
ordering_filters = [
filter_cls for filter_cls in getattr(view, 'filter_backends', [])
if hasattr(filter_cls, 'get_ordering')
@ -805,26 +823,19 @@ class CursorPagination(BasePagination):
# then we defer to that filter to determine the ordering.
filter_cls = ordering_filters[0]
filter_instance = filter_cls()
ordering = filter_instance.get_ordering(request, queryset, view)
assert ordering is not None, (
'Using cursor pagination, but filter class {filter_cls} '
'returned a `None` ordering.'.format(
filter_cls=filter_cls.__name__
)
)
else:
# The default case is to check for an `ordering` attribute
# on this pagination instance.
ordering = self.ordering
assert ordering is not None, (
'Using cursor pagination, but no ordering attribute was declared '
'on the pagination class.'
)
assert '__' not in ordering, (
'Cursor pagination does not support double underscore lookups '
'for orderings. Orderings should be an unchanging, unique or '
'nearly-unique field on the model, such as "-created" or "pk".'
)
ordering_from_filter = filter_instance.get_ordering(request, queryset, view)
if ordering_from_filter:
ordering = ordering_from_filter
assert ordering is not None, (
'Using cursor pagination, but no ordering attribute was declared '
'on the pagination class.'
)
assert '__' not in ordering, (
'Cursor pagination does not support double underscore lookups '
'for orderings. Orderings should be an unchanging, unique or '
'nearly-unique field on the model, such as "-created" or "pk".'
)
assert isinstance(ordering, (str, list, tuple)), (
'Invalid ordering. Expected string or tuple, but got {type}'.format(
@ -883,14 +894,14 @@ class CursorPagination(BasePagination):
attr = instance[field_name]
else:
attr = getattr(instance, field_name)
return str(attr)
return None if attr is None else str(attr)
def get_paginated_response(self, data):
return Response(OrderedDict([
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
return Response({
'next': self.get_next_link(),
'previous': self.get_previous_link(),
'results': data,
})
def get_paginated_response_schema(self, schema):
return {
@ -927,6 +938,8 @@ class CursorPagination(BasePagination):
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [
coreapi.Field(

View File

@ -1,6 +1,6 @@
import contextlib
import sys
from collections import OrderedDict
from operator import attrgetter
from urllib import parse
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
@ -71,6 +71,7 @@ class PKOnlyObject:
instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance.
"""
def __init__(self, pk):
self.pk = pk
@ -197,13 +198,9 @@ class RelatedField(Field):
if cutoff is not None:
queryset = queryset[:cutoff]
return OrderedDict([
(
self.to_representation(item),
self.display_value(item)
)
for item in queryset
])
return {
self.to_representation(item): self.display_value(item) for item in queryset
}
@property
def choices(self):
@ -464,7 +461,11 @@ class SlugRelatedField(RelatedField):
self.fail('invalid')
def to_representation(self, obj):
return getattr(obj, self.slug_field)
slug = self.slug_field
if "__" in slug:
# handling nested relationship if defined
slug = slug.replace('__', '.')
return attrgetter(slug)(obj)
class ManyRelatedField(Field):

View File

@ -9,7 +9,7 @@ REST framework also provides an HTML renderer that renders the browsable API.
import base64
import contextlib
from collections import OrderedDict
import datetime
from urllib import parse
from django import forms
@ -507,6 +507,9 @@ class BrowsableAPIRenderer(BaseRenderer):
return self.render_form_for_serializer(serializer)
def render_form_for_serializer(self, serializer):
if isinstance(serializer, serializers.ListSerializer):
return None
if hasattr(serializer, 'initial_data'):
serializer.is_valid()
@ -556,10 +559,13 @@ class BrowsableAPIRenderer(BaseRenderer):
context['indent'] = 4
# strip HiddenField from output
is_list_serializer = isinstance(serializer, serializers.ListSerializer)
serializer = serializer.child if is_list_serializer else serializer
data = serializer.data.copy()
for name, field in serializer.fields.items():
if isinstance(field, serializers.HiddenField):
data.pop(name, None)
data = [data] if is_list_serializer else data
content = renderer.render(data, accepted, context)
# Renders returns bytes, but CharField expects a str.
content = content.decode()
@ -653,7 +659,7 @@ class BrowsableAPIRenderer(BaseRenderer):
raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
response_headers = OrderedDict(sorted(response.items()))
response_headers = dict(sorted(response.items()))
renderer_content_type = ''
if renderer:
renderer_content_type = '%s' % renderer.media_type
@ -1057,6 +1063,7 @@ class OpenAPIRenderer(BaseRenderer):
def ignore_aliases(self, data):
return True
Dumper.add_representer(SafeString, Dumper.represent_str)
Dumper.add_representer(datetime.timedelta, encoders.CustomScalar.represent_timedelta)
return yaml.dump(data, default_flow_style=False, sort_keys=False, Dumper=Dumper).encode('utf-8')

View File

@ -14,7 +14,7 @@ For example, you might have a `urls.py` that looks something like this:
urlpatterns = router.urls
"""
import itertools
from collections import OrderedDict, namedtuple
from collections import namedtuple
from django.core.exceptions import ImproperlyConfigured
from django.urls import NoReverseMatch, path, re_path
@ -321,7 +321,7 @@ class APIRootView(views.APIView):
def get(self, request, *args, **kwargs):
# Return a plain {"name": "hyperlink"} response.
ret = OrderedDict()
ret = {}
namespace = request.resolver_match.namespace
for key, url_name in self.api_root_dict.items():
if namespace:
@ -365,7 +365,7 @@ class DefaultRouter(SimpleRouter):
"""
Return a basic root view.
"""
api_root_dict = OrderedDict()
api_root_dict = {}
list_name = self.routes[0].name
for prefix, viewset, basename in self.registry:
api_root_dict[prefix] = list_name.format(basename=basename)

View File

@ -1,11 +1,11 @@
import warnings
from collections import Counter, OrderedDict
from collections import Counter
from urllib import parse
from django.db import models
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework import RemovedInDRF317Warning, exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings
@ -54,7 +54,7 @@ to customise schema structure.
"""
class LinkNode(OrderedDict):
class LinkNode(dict):
def __init__(self):
self.links = []
self.methods_counter = Counter()
@ -118,6 +118,8 @@ class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
assert coreapi, '`coreapi` must be installed for schema support.'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema, '`coreschema` must be installed for schema support.'
super().__init__(title, url, description, patterns, urlconf)
@ -268,11 +270,11 @@ def field_to_schema(field):
)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
properties={
key: field_to_schema(value)
for key, value
in field.fields.items()
]),
},
title=title,
description=description
)
@ -351,6 +353,9 @@ class AutoSchema(ViewInspector):
will be added to auto-generated fields, overwriting on `Field.name`
"""
super().__init__()
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
if manual_fields is None:
manual_fields = []
self._manual_fields = manual_fields
@ -549,7 +554,7 @@ class AutoSchema(ViewInspector):
if not update_with:
return fields
by_name = OrderedDict((f.name, f) for f in fields)
by_name = {f.name: f for f in fields}
for f in update_with:
by_name[f.name] = f
fields = list(by_name.values())
@ -592,6 +597,9 @@ class ManualSchema(ViewInspector):
* `description`: String description for view. Optional.
"""
super().__init__()
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
@ -613,4 +621,6 @@ class ManualSchema(ViewInspector):
def is_enabled():
"""Is CoreAPI Mode enabled?"""
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)

View File

@ -1,6 +1,5 @@
import re
import warnings
from collections import OrderedDict
from decimal import Decimal
from operator import attrgetter
from urllib.parse import urljoin
@ -340,7 +339,7 @@ class AutoSchema(ViewInspector):
return paginator.get_schema_operation_parameters(view)
def map_choicefield(self, field):
choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates
choices = list(dict.fromkeys(field.choices)) # preserve order and remove duplicates
if all(isinstance(choice, bool) for choice in choices):
type = 'boolean'
elif all(isinstance(choice, int) for choice in choices):

View File

@ -15,7 +15,7 @@ import contextlib
import copy
import inspect
import traceback
from collections import OrderedDict, defaultdict
from collections import defaultdict
from collections.abc import Mapping
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
@ -315,7 +315,7 @@ class SerializerMetaclass(type):
for name, f in base._declared_fields.items() if name not in known
]
return OrderedDict(base_fields + fields)
return dict(base_fields + fields)
def __new__(cls, name, bases, attrs):
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
@ -353,6 +353,26 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
}
def set_value(self, dictionary, keys, value):
"""
Similar to Python's built in `dictionary[key] = value`,
but takes a list of nested keys instead of a single key.
set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
"""
if not keys:
dictionary.update(value)
return
for key in keys[:-1]:
if key not in dictionary:
dictionary[key] = {}
dictionary = dictionary[key]
dictionary[keys[-1]] = value
@cached_property
def fields(self):
"""
@ -400,20 +420,20 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
if hasattr(self, 'initial_data'):
# initial_data may not be a valid type
if not isinstance(self.initial_data, Mapping):
return OrderedDict()
return {}
return OrderedDict([
(field_name, field.get_value(self.initial_data))
return {
field_name: field.get_value(self.initial_data)
for field_name, field in self.fields.items()
if (field.get_value(self.initial_data) is not empty) and
not field.read_only
])
}
return OrderedDict([
(field.field_name, field.get_initial())
return {
field.field_name: field.get_initial()
for field in self.fields.values()
if not field.read_only
])
}
def get_value(self, dictionary):
# We override the default field access in order to support
@ -448,7 +468,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
]
defaults = OrderedDict()
defaults = {}
for field in fields:
try:
default = field.get_default()
@ -481,8 +501,8 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='invalid')
ret = OrderedDict()
errors = OrderedDict()
ret = {}
errors = {}
fields = self._writable_fields
for field in fields:
@ -499,7 +519,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
except SkipField:
pass
else:
set_value(ret, field.source_attrs, validated_value)
self.set_value(ret, field.source_attrs, validated_value)
if errors:
raise ValidationError(errors)
@ -510,7 +530,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
"""
Object instance -> Dict of primitive datatypes.
"""
ret = OrderedDict()
ret = {}
fields = self._readable_fields
for field in fields:
@ -596,6 +616,12 @@ class ListSerializer(BaseSerializer):
self.min_length = kwargs.pop('min_length', None)
assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.'
instance = kwargs.get('instance', [])
data = kwargs.get('data', [])
if instance and data:
assert len(data) == len(instance), 'Data and instance should have same length'
super().__init__(*args, **kwargs)
self.child.bind(field_name='', parent=self)
@ -670,7 +696,13 @@ class ListSerializer(BaseSerializer):
ret = []
errors = []
for item in data:
for idx, item in enumerate(data):
if (
hasattr(self, 'instance')
and self.instance
and len(self.instance) > idx
):
self.child.instance = self.instance[idx]
try:
validated = self.child.run_validation(item)
except ValidationError as exc:
@ -1068,7 +1100,7 @@ class ModelSerializer(Serializer):
)
# Determine the fields that should be included on the serializer.
fields = OrderedDict()
fields = {}
for field_name in field_names:
# If the field is explicitly declared on the class then use that.
@ -1553,16 +1585,16 @@ class ModelSerializer(Serializer):
# which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not
# including those.
field_sources = OrderedDict(
(field.field_name, field.source) for field in self._writable_fields
field_sources = {
field.field_name: field.source for field in self._writable_fields
if (field.source != '*') and ('.' not in field.source)
)
}
# Special Case: Add read_only fields with defaults.
field_sources.update(OrderedDict(
(field.field_name, field.source) for field in self.fields.values()
field_sources.update({
field.field_name: field.source for field in self.fields.values()
if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
))
})
# Invert so we can find the serializer field names that correspond to
# the model field names in the unique_together sets. This also allows

File diff suppressed because one or more lines are too long

View File

@ -250,7 +250,7 @@
"csrfToken": "{{ csrf_token }}"
}
</script>
<script src="{% static "rest_framework/js/jquery-3.5.1.min.js" %}"></script>
<script src="{% static "rest_framework/js/jquery-3.6.4.min.js" %}"></script>
<script src="{% static "rest_framework/js/ajax-form.js" %}"></script>
<script src="{% static "rest_framework/js/csrf.js" %}"></script>
<script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script>

View File

@ -293,7 +293,7 @@
"csrfToken": "{% if request %}{{ csrf_token }}{% endif %}"
}
</script>
<script src="{% static "rest_framework/js/jquery-3.5.1.min.js" %}"></script>
<script src="{% static "rest_framework/js/jquery-3.6.4.min.js" %}"></script>
<script src="{% static "rest_framework/js/ajax-form.js" %}"></script>
<script src="{% static "rest_framework/js/csrf.js" %}"></script>
<script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script>

View File

@ -30,7 +30,7 @@ being applied unexpectedly?</p>
<p>Your response status code is: <code>{{ response.status_code }}</code></p>
<h3>401 Unauthorised.</h3>
<h3>401 Unauthorized.</h3>
<ul>
<li>Do you have SessionAuthentication enabled?</li>
<li>Are you logged in?</li>
@ -66,6 +66,6 @@ at <code>rest_framework/docs/error.html</code>.</p>
<script src="{% static 'rest_framework/js/jquery-3.5.1.min.js' %}"></script>
<script src="{% static 'rest_framework/js/jquery-3.6.4.min.js' %}"></script>
</body>
</html>

View File

@ -38,7 +38,7 @@
{% include "rest_framework/docs/auth/basic.html" %}
{% include "rest_framework/docs/auth/session.html" %}
<script src="{% static 'rest_framework/js/jquery-3.5.1.min.js' %}"></script>
<script src="{% static 'rest_framework/js/jquery-3.6.4.min.js' %}"></script>
<script src="{% static 'rest_framework/js/bootstrap.min.js' %}"></script>
<script src="{% static 'rest_framework/docs/js/jquery.json-view.min.js' %}"></script>
<script src="{% static 'rest_framework/docs/js/api.js' %}"></script>

View File

@ -11,7 +11,7 @@
{% endif %}
<div class="col-sm-10">
<select multiple {{ field.choices|yesno:",disabled" }} class="form-control" name="{{ field.name }}">
<select multiple class="form-control" name="{{ field.name }}">
{% for select in field.iter_options %}
{% if select.start_option_group %}
<optgroup label="{{ select.label }}">

View File

@ -1,5 +1,4 @@
import re
from collections import OrderedDict
from django import template
from django.template import loader
@ -49,10 +48,10 @@ def with_location(fields, location):
@register.simple_tag
def form_for_link(link):
import coreschema
properties = OrderedDict([
(field.name, field.schema or coreschema.String())
properties = {
field.name: field.schema or coreschema.String()
for field in link.fields
])
}
required = [
field.name
for field in link.fields
@ -272,7 +271,7 @@ def schema_links(section, sec_key=None):
links.update(new_links)
if sec_key is not None:
new_links = OrderedDict()
new_links = {}
for link_key, link in links.items():
new_key = NESTED_FORMAT % (sec_key, link_key)
new_links.update({new_key: link})

View File

@ -65,3 +65,14 @@ class JSONEncoder(json.JSONEncoder):
elif hasattr(obj, '__iter__'):
return tuple(item for item in obj)
return super().default(obj)
class CustomScalar:
"""
CustomScalar that knows how to encode timedelta that renderer
can understand.
"""
@classmethod
def represent_timedelta(cls, dumper, data):
value = str(data.total_seconds())
return dumper.represent_scalar('tag:yaml.org,2002:str', value)

View File

@ -5,7 +5,7 @@ relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import OrderedDict, namedtuple
from collections import namedtuple
FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance
@ -58,7 +58,7 @@ def _get_pk(opts):
def _get_fields(opts):
fields = OrderedDict()
fields = {}
for field in [field for field in opts.fields if field.serialize and not field.remote_field]:
fields[field.name] = field
@ -71,9 +71,9 @@ def _get_to_field(field):
def _get_forward_relationships(opts):
"""
Returns an `OrderedDict` of field names to `RelationInfo`.
Returns a dict of field names to `RelationInfo`.
"""
forward_relations = OrderedDict()
forward_relations = {}
for field in [field for field in opts.fields if field.serialize and field.remote_field]:
forward_relations[field.name] = RelationInfo(
model_field=field,
@ -103,9 +103,9 @@ def _get_forward_relationships(opts):
def _get_reverse_relationships(opts):
"""
Returns an `OrderedDict` of field names to `RelationInfo`.
Returns a dict of field names to `RelationInfo`.
"""
reverse_relations = OrderedDict()
reverse_relations = {}
all_related_objects = [r for r in opts.related_objects if not r.field.many_to_many]
for relation in all_related_objects:
accessor_name = relation.get_accessor_name()
@ -139,19 +139,14 @@ def _get_reverse_relationships(opts):
def _merge_fields_and_pk(pk, fields):
fields_and_pk = OrderedDict()
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk = {'pk': pk, pk.name: pk}
fields_and_pk.update(fields)
return fields_and_pk
def _merge_relationships(forward_relations, reverse_relations):
return OrderedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
)
return {**forward_relations, **reverse_relations}
def is_abstract_model(model):

View File

@ -1,6 +1,5 @@
import contextlib
import sys
from collections import OrderedDict
from collections.abc import Mapping, MutableMapping
from django.utils.encoding import force_str
@ -8,7 +7,7 @@ from django.utils.encoding import force_str
from rest_framework.utils import json
class ReturnDict(OrderedDict):
class ReturnDict(dict):
"""
Return object from `serializer.data` for the `Serializer` class.
Includes a backlink to the serializer instance for renderers
@ -161,7 +160,7 @@ class BindingDict(MutableMapping):
def __init__(self, serializer):
self.serializer = serializer
self.fields = OrderedDict()
self.fields = {}
def __setitem__(self, key, field):
self.fields[key] = field

View File

@ -0,0 +1,25 @@
from datetime import datetime, timezone, tzinfo
def datetime_exists(dt):
"""Check if a datetime exists. Taken from: https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html"""
# There are no non-existent times in UTC, and comparisons between
# aware time zones always compare absolute times; if a datetime is
# not equal to the same datetime represented in UTC, it is imaginary.
return dt.astimezone(timezone.utc) == dt
def datetime_ambiguous(dt: datetime):
"""Check whether a datetime is ambiguous. Taken from: https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html"""
# If a datetime exists and its UTC offset changes in response to
# changing `fold`, it is ambiguous in the zone specified.
return datetime_exists(dt) and (
dt.replace(fold=not dt.fold).utcoffset() != dt.utcoffset()
)
def valid_datetime(dt):
"""Returns True if the datetime is not ambiguous or imaginary, False otherwise."""
if isinstance(dt.tzinfo, tzinfo) and not datetime_ambiguous(dt):
return True
return False

View File

@ -79,6 +79,15 @@ class UniqueValidator:
smart_repr(self.queryset)
)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.lookup == other.lookup
)
class UniqueTogetherValidator:
"""
@ -166,6 +175,16 @@ class UniqueTogetherValidator:
smart_repr(self.fields)
)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.missing_message == other.missing_message
and self.queryset == other.queryset
and self.fields == other.fields
)
class ProhibitSurrogateCharactersValidator:
message = _('Surrogate characters are not allowed: U+{code_point:X}.')
@ -177,6 +196,13 @@ class ProhibitSurrogateCharactersValidator:
message = self.message.format(code_point=ord(surrogate_character))
raise ValidationError(message, code=self.code)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.code == other.code
)
class BaseUniqueForValidator:
message = None
@ -230,6 +256,17 @@ class BaseUniqueForValidator:
self.field: message
}, code='unique')
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.missing_message == other.missing_message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.field == other.field
and self.date_field == other.date_field
)
def __repr__(self):
return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
self.__class__.__name__,

View File

@ -81,8 +81,10 @@ class URLPathVersioning(BaseVersioning):
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None:
kwargs = {} if (kwargs is None) else kwargs
kwargs[self.version_param] = request.version
kwargs = {
self.version_param: request.version,
**(kwargs or {})
}
return super().reverse(
viewname, args, kwargs, request, format, **extra
@ -117,15 +119,16 @@ class NamespaceVersioning(BaseVersioning):
def determine_version(self, request, *args, **kwargs):
resolver_match = getattr(request, 'resolver_match', None)
if resolver_match is None or not resolver_match.namespace:
return self.default_version
if resolver_match is not None and resolver_match.namespace:
# Allow for possibly nested namespaces.
possible_versions = resolver_match.namespace.split(':')
for version in possible_versions:
if self.is_allowed_version(version):
return version
# Allow for possibly nested namespaces.
possible_versions = resolver_match.namespace.split(':')
for version in possible_versions:
if self.is_allowed_version(version):
return version
raise exceptions.NotFound(self.invalid_version_message)
if not self.is_allowed_version(self.default_version):
raise exceptions.NotFound(self.invalid_version_message)
return self.default_version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None:

View File

@ -16,7 +16,6 @@ automatically.
router.register(r'users', UserViewSet, 'user')
urlpatterns = router.urls
"""
from collections import OrderedDict
from functools import update_wrapper
from inspect import getmembers
@ -183,7 +182,7 @@ class ViewSetMixin:
This method will noop if `detail` was not provided as a view initkwarg.
"""
action_urls = OrderedDict()
action_urls = {}
# exit early if `detail` has not been provided
if self.detail is None:

View File

@ -3,6 +3,8 @@ license_files = LICENSE.md
[tool:pytest]
addopts=--tb=short --strict-markers -ra
testspath = tests
filterwarnings = ignore:CoreAPI compatibility is deprecated*:rest_framework.RemovedInDRF317Warning
[flake8]
ignore = E501,W503,W504

View File

@ -37,7 +37,8 @@ an older version of Django REST Framework:
def read(f):
return open(f, 'r', encoding='utf-8').read()
with open(f, 'r', encoding='utf-8') as file:
return file.read()
def get_version(package):
@ -82,7 +83,7 @@ setup(
author_email='tom@tomchristie.com', # SEE NOTE BELOW (*)
packages=find_packages(exclude=['tests*']),
include_package_data=True,
install_requires=["django>=3.0", "pytz"],
install_requires=["django>=3.0", 'backports.zoneinfo;python_version<"3.9"'],
python_requires=">=3.6",
zip_safe=False,
classifiers=[

View File

@ -7,16 +7,24 @@ from django.test import TestCase, override_settings
from django.urls import include, path
from rest_framework import (
filters, generics, pagination, permissions, serializers
RemovedInDRF317Warning, filters, generics, pagination, permissions,
serializers
)
from rest_framework.compat import coreapi, coreschema
from rest_framework.decorators import action, api_view, schema
from rest_framework.filters import (
BaseFilterBackend, OrderingFilter, SearchFilter
)
from rest_framework.pagination import (
BasePagination, CursorPagination, LimitOffsetPagination,
PageNumberPagination
)
from rest_framework.request import Request
from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.schemas import (
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
)
from rest_framework.schemas.coreapi import field_to_schema
from rest_framework.schemas.coreapi import field_to_schema, is_enabled
from rest_framework.schemas.generators import EndpointEnumerator
from rest_framework.schemas.utils import is_list_view
from rest_framework.test import APIClient, APIRequestFactory
@ -1433,3 +1441,46 @@ def test_schema_handles_exception():
response.render()
assert response.status_code == 403
assert b"You do not have permission to perform this action." in response.content
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_coreapi_deprecation():
with pytest.warns(RemovedInDRF317Warning):
SchemaGenerator()
with pytest.warns(RemovedInDRF317Warning):
AutoSchema()
with pytest.warns(RemovedInDRF317Warning):
ManualSchema({})
with pytest.warns(RemovedInDRF317Warning):
deprecated_filter = OrderingFilter()
deprecated_filter.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
deprecated_filter = BaseFilterBackend()
deprecated_filter.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
deprecated_filter = SearchFilter()
deprecated_filter.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = BasePagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = PageNumberPagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = LimitOffsetPagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = CursorPagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
is_enabled()

View File

@ -1162,6 +1162,31 @@ class TestGenerator(TestCase):
assert b'"openapi": "' in ret
assert b'"default": "0.0"' in ret
def test_schema_rendering_to_yaml(self):
patterns = [
path('example/', views.ExampleGenericAPIView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
ret = OpenAPIRenderer().render(schema)
assert b"openapi: " in ret
assert b"default: '0.0'" in ret
def test_schema_rendering_timedelta_to_yaml_with_validator(self):
patterns = [
path('example/', views.ExampleValidatedAPIView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
ret = OpenAPIRenderer().render(schema)
assert b"openapi: " in ret
assert b"duration:\n type: string\n minimum: \'10.0\'\n" in ret
def test_schema_with_no_paths(self):
patterns = []
generator = SchemaGenerator(patterns=patterns)

View File

@ -134,6 +134,11 @@ class ExampleValidatedSerializer(serializers.Serializer):
ip4 = serializers.IPAddressField(protocol='ipv4')
ip6 = serializers.IPAddressField(protocol='ipv6')
ip = serializers.IPAddressField()
duration = serializers.DurationField(
validators=(
MinValueValidator(timedelta(seconds=10)),
)
)
class ExampleValidatedAPIView(generics.GenericAPIView):

View File

@ -5,10 +5,18 @@ import re
import sys
import uuid
from decimal import ROUND_DOWN, ROUND_UP, Decimal
from enum import auto
from unittest.mock import patch
import pytest
import pytz
try:
import pytz
except ImportError:
pytz = None
from django.core.exceptions import ValidationError as DjangoValidationError
from django.db.models import IntegerChoices, TextChoices
from django.http import QueryDict
from django.test import TestCase, override_settings
from django.utils.timezone import activate, deactivate, override
@ -21,6 +29,11 @@ from rest_framework.fields import (
)
from tests.models import UUIDForeignKeyTarget
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo
utc = datetime.timezone.utc
# Tests for helper functions.
@ -651,7 +664,7 @@ class FieldValues:
"""
Base class for testing valid and invalid input values.
"""
def test_valid_inputs(self):
def test_valid_inputs(self, *args):
"""
Ensure that valid values return the expected validated data.
"""
@ -659,7 +672,7 @@ class FieldValues:
assert self.field.run_validation(input_value) == expected_output, \
'input value: {}'.format(repr(input_value))
def test_invalid_inputs(self):
def test_invalid_inputs(self, *args):
"""
Ensure that invalid values raise the expected validation error.
"""
@ -669,7 +682,7 @@ class FieldValues:
assert exc_info.value.detail == expected_failure, \
'input value: {}'.format(repr(input_value))
def test_outputs(self):
def test_outputs(self, *args):
for output_value, expected_output in get_items(self.outputs):
assert self.field.to_representation(output_value) == expected_output, \
'output value: {}'.format(repr(output_value))
@ -682,8 +695,24 @@ class TestBooleanField(FieldValues):
Valid and invalid values for `BooleanField`.
"""
valid_inputs = {
'True': True,
'TRUE': True,
'tRuE': True,
't': True,
'T': True,
'true': True,
'on': True,
'ON': True,
'oN': True,
'False': False,
'FALSE': False,
'fALse': False,
'f': False,
'F': False,
'false': False,
'off': False,
'OFF': False,
'oFf': False,
'1': True,
'0': False,
1: True,
@ -696,8 +725,24 @@ class TestBooleanField(FieldValues):
None: ['This field may not be null.']
}
outputs = {
'True': True,
'TRUE': True,
'tRuE': True,
't': True,
'T': True,
'true': True,
'on': True,
'ON': True,
'oN': True,
'False': False,
'FALSE': False,
'fALse': False,
'f': False,
'F': False,
'false': False,
'off': False,
'OFF': False,
'oFf': False,
'1': True,
'0': False,
1: True,
@ -1208,6 +1253,17 @@ class TestMinMaxDecimalField(FieldValues):
min_value=10, max_value=20
)
def test_warning_when_not_decimal_types(self, caplog):
import logging
serializers.DecimalField(
max_digits=3, decimal_places=1,
min_value=10, max_value=20
)
assert caplog.record_tuples == [
("rest_framework.fields", logging.WARNING, "max_value in DecimalField should be Decimal type."),
("rest_framework.fields", logging.WARNING, "min_value in DecimalField should be Decimal type.")
]
class TestAllowEmptyStrDecimalFieldWithValidators(FieldValues):
"""
@ -1505,12 +1561,12 @@ class TestTZWithDateTimeField(FieldValues):
@classmethod
def setup_class(cls):
# use class setup method, as class-level attribute will still be evaluated even if test is skipped
kolkata = pytz.timezone('Asia/Kolkata')
kolkata = ZoneInfo('Asia/Kolkata')
cls.valid_inputs = {
'2016-12-19T10:00:00': kolkata.localize(datetime.datetime(2016, 12, 19, 10)),
'2016-12-19T10:00:00+05:30': kolkata.localize(datetime.datetime(2016, 12, 19, 10)),
datetime.datetime(2016, 12, 19, 10): kolkata.localize(datetime.datetime(2016, 12, 19, 10)),
'2016-12-19T10:00:00': datetime.datetime(2016, 12, 19, 10, tzinfo=kolkata),
'2016-12-19T10:00:00+05:30': datetime.datetime(2016, 12, 19, 10, tzinfo=kolkata),
datetime.datetime(2016, 12, 19, 10): datetime.datetime(2016, 12, 19, 10, tzinfo=kolkata),
}
cls.invalid_inputs = {}
cls.outputs = {
@ -1529,7 +1585,7 @@ class TestDefaultTZDateTimeField(TestCase):
@classmethod
def setup_class(cls):
cls.field = serializers.DateTimeField()
cls.kolkata = pytz.timezone('Asia/Kolkata')
cls.kolkata = ZoneInfo('Asia/Kolkata')
def assertUTC(self, tzinfo):
"""
@ -1551,18 +1607,17 @@ class TestDefaultTZDateTimeField(TestCase):
self.assertUTC(self.field.default_timezone())
@pytest.mark.skipif(pytz is None, reason='pytz not installed')
@override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestCustomTimezoneForDateTimeField(TestCase):
@classmethod
def setup_class(cls):
cls.kolkata = pytz.timezone('Asia/Kolkata')
cls.kolkata = ZoneInfo('Asia/Kolkata')
cls.date_format = '%d/%m/%Y %H:%M'
def test_should_render_date_time_in_default_timezone(self):
field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=ZoneInfo("UTC"))
with override(self.kolkata):
rendered_date = field.to_representation(dt)
@ -1572,6 +1627,33 @@ class TestCustomTimezoneForDateTimeField(TestCase):
assert rendered_date == rendered_date_in_timezone
@pytest.mark.skipif(pytz is None, reason="As Django 4.0 has deprecated pytz, this test should eventually be able to get removed.")
class TestPytzNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
"""
Invalid values for `DateTimeField` with datetime in DST shift (non-existing or ambiguous) and timezone with DST.
Timezone America/New_York has DST shift from 2017-03-12T02:00:00 to 2017-03-12T03:00:00 and
from 2017-11-05T02:00:00 to 2017-11-05T01:00:00 in 2017.
"""
valid_inputs = {}
invalid_inputs = {
'2017-03-12T02:30:00': ['Invalid datetime for the timezone "America/New_York".'],
'2017-11-05T01:30:00': ['Invalid datetime for the timezone "America/New_York".']
}
outputs = {}
if pytz:
class MockTimezone(pytz.BaseTzInfo):
@staticmethod
def localize(value, is_dst):
raise pytz.InvalidTimeError()
def __str__(self):
return 'America/New_York'
field = serializers.DateTimeField(default_timezone=MockTimezone())
@patch('rest_framework.utils.timezone.datetime_ambiguous', return_value=True)
class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
"""
Invalid values for `DateTimeField` with datetime in DST shift (non-existing or ambiguous) and timezone with DST.
@ -1585,15 +1667,11 @@ class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
}
outputs = {}
class MockTimezone(pytz.BaseTzInfo):
@staticmethod
def localize(value, is_dst):
raise pytz.InvalidTimeError()
class MockZoneInfoTimezone(datetime.tzinfo):
def __str__(self):
return 'America/New_York'
field = serializers.DateTimeField(default_timezone=MockTimezone())
field = serializers.DateTimeField(default_timezone=MockZoneInfoTimezone())
class TestTimeField(FieldValues):
@ -1797,6 +1875,54 @@ class TestChoiceField(FieldValues):
field.run_validation(2)
assert exc_info.value.detail == ['"2" is not a valid choice.']
def test_integer_choices(self):
class ChoiceCase(IntegerChoices):
first = auto()
second = auto()
# Enum validate
choices = [
(ChoiceCase.first, "1"),
(ChoiceCase.second, "2")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1
choices = [
(ChoiceCase.first.value, "1"),
(ChoiceCase.second.value, "2")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1
def test_text_choices(self):
class ChoiceCase(TextChoices):
first = auto()
second = auto()
# Enum validate
choices = [
(ChoiceCase.first, "first"),
(ChoiceCase.second, "second")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(ChoiceCase.first) == "first"
assert field.run_validation("first") == "first"
choices = [
(ChoiceCase.first.value, "first"),
(ChoiceCase.second.value, "second")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(ChoiceCase.first) == "first"
assert field.run_validation("first") == "first"
class TestChoiceFieldWithType(FieldValues):
"""

View File

@ -324,6 +324,13 @@ class TestSimpleMetadataFieldInfo(TestCase):
)
assert 'choices' not in field_info
def test_decimal_field_info_type(self):
options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.DecimalField(max_digits=18, decimal_places=4))
assert field_info['type'] == 'decimal'
assert field_info['max_digits'] == 18
assert field_info['decimal_places'] == 4
class TestModelSerializerMetadata(TestCase):
def test_read_only_primary_key_related_field(self):

View File

@ -10,7 +10,6 @@ import decimal
import json # noqa
import sys
import tempfile
from collections import OrderedDict
import django
import pytest
@ -762,7 +761,7 @@ class TestRelationalFieldDisplayValue(TestCase):
fields = '__all__'
serializer = TestSerializer()
expected = OrderedDict([(1, 'Red Color'), (2, 'Yellow Color'), (3, 'Green Color')])
expected = {1: 'Red Color', 2: 'Yellow Color', 3: 'Green Color'}
self.assertEqual(serializer.fields['color'].choices, expected)
def test_custom_display_value(self):
@ -778,7 +777,7 @@ class TestRelationalFieldDisplayValue(TestCase):
fields = '__all__'
serializer = TestSerializer()
expected = OrderedDict([(1, 'My Red Color'), (2, 'My Yellow Color'), (3, 'My Green Color')])
expected = {1: 'My Red Color', 2: 'My Yellow Color', 3: 'My Green Color'}
self.assertEqual(serializer.fields['color'].choices, expected)

View File

@ -632,6 +632,24 @@ class CursorPaginationTestsMixin:
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',)
def test_use_with_ordering_filter_without_ordering_default_value(self):
class MockView:
filter_backends = (filters.OrderingFilter,)
ordering_fields = ['username', 'created']
request = Request(factory.get('/'))
ordering = self.pagination.get_ordering(request, [], MockView())
# it gets the value of `ordering` provided by CursorPagination
assert ordering == ('created',)
request = Request(factory.get('/', {'ordering': 'username'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('username',)
request = Request(factory.get('/', {'ordering': 'invalid'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',)
def test_cursor_pagination(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/')
@ -951,17 +969,24 @@ class TestCursorPagination(CursorPaginationTestsMixin):
def __init__(self, items):
self.items = items
def filter(self, created__gt=None, created__lt=None):
def filter(self, q):
q_args = dict(q.deconstruct()[1])
if not q_args:
# django 3.0.x artifact
q_args = dict(q.deconstruct()[2])
created__gt = q_args.get('created__gt')
created__lt = q_args.get('created__lt')
if created__gt is not None:
return MockQuerySet([
item for item in self.items
if item.created > int(created__gt)
if item.created is None or item.created > int(created__gt)
])
assert created__lt is not None
return MockQuerySet([
item for item in self.items
if item.created < int(created__lt)
if item.created is None or item.created < int(created__lt)
])
def order_by(self, *ordering):
@ -1080,6 +1105,127 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
return (previous, current, next, previous_url, next_url)
class NullableCursorPaginationModel(models.Model):
created = models.IntegerField(null=True)
class TestCursorPaginationWithNulls(TestCase):
"""
Unit tests for `pagination.CursorPagination` with ordering on a nullable field.
"""
def setUp(self):
class ExamplePagination(pagination.CursorPagination):
page_size = 1
ordering = 'created'
self.pagination = ExamplePagination()
data = [
None, None, 3, 4
]
for idx in data:
NullableCursorPaginationModel.objects.create(created=idx)
self.queryset = NullableCursorPaginationModel.objects.all()
get_pages = TestCursorPagination.get_pages
def test_ascending(self):
"""Test paginating one row at a time, current should go 1, 2, 3, 4, 3, 2, 1."""
(previous, current, next, previous_url, next_url) = self.get_pages('/')
assert previous is None
assert current == [None]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None]
assert current == [None]
assert next == [3]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [3] # [None] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L789
assert current == [3]
assert next == [4]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [3]
assert current == [4]
assert next is None
assert next_url is None
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [None]
assert current == [3]
assert next == [4]
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [None]
assert current == [None]
assert next == [None] # [3] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous is None
assert current == [None]
assert next == [None]
def test_descending(self):
"""Test paginating one row at a time, current should go 4, 3, 2, 1, 2, 3, 4."""
self.pagination.ordering = ('-created',)
(previous, current, next, previous_url, next_url) = self.get_pages('/')
assert previous is None
assert current == [4]
assert next == [3]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None] # [4] paging artifact
assert current == [3]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None] # [3] paging artifact
assert current == [None]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None]
assert current == [None]
assert next is None
assert next_url is None
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [3]
assert current == [None]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [None]
assert current == [3]
assert next == [3] # [4] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731
# skip back artifact
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous is None
assert current == [4]
assert next == [3]
def test_get_displayed_page_numbers():
"""
Test our contextual page display function.

View File

@ -342,6 +342,142 @@ class TestSlugRelatedField(APISimpleTestCase):
field.to_internal_value(self.instance.name)
class TestNestedSlugRelatedField(APISimpleTestCase):
def setUp(self):
self.queryset = MockQueryset([
MockObject(
pk=1, name='foo', nested=MockObject(
pk=2, name='bar', nested=MockObject(
pk=7, name="foobar"
)
)
),
MockObject(
pk=3, name='hello', nested=MockObject(
pk=4, name='world', nested=MockObject(
pk=8, name="helloworld"
)
)
),
MockObject(
pk=5, name='harry', nested=MockObject(
pk=6, name='potter', nested=MockObject(
pk=9, name="harrypotter"
)
)
)
])
self.instance = self.queryset.items[2]
self.field = serializers.SlugRelatedField(
slug_field='name', queryset=self.queryset
)
self.nested_field = serializers.SlugRelatedField(
slug_field='nested__name', queryset=self.queryset
)
self.nested_nested_field = serializers.SlugRelatedField(
slug_field='nested__nested__name', queryset=self.queryset
)
# testing nested inside nested relations
def test_slug_related_nested_nested_lookup_exists(self):
instance = self.nested_nested_field.to_internal_value(
self.instance.nested.nested.name
)
assert instance is self.instance
def test_slug_related_nested_nested_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_nested_field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0]
assert msg == \
'Object with nested__nested__name=doesnotexist does not exist.'
def test_slug_related_nested_nested_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_nested_field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Invalid value.'
def test_nested_nested_representation(self):
representation =\
self.nested_nested_field.to_representation(self.instance)
assert representation == self.instance.nested.nested.name
def test_nested_nested_overriding_get_queryset(self):
qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
return qs
field = NoQuerySetSlugRelatedField(slug_field='nested__nested__name')
field.to_internal_value(self.instance.nested.nested.name)
# testing nested relations
def test_slug_related_nested_lookup_exists(self):
instance = \
self.nested_field.to_internal_value(self.instance.nested.name)
assert instance is self.instance
def test_slug_related_nested_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0]
assert msg == 'Object with nested__name=doesnotexist does not exist.'
def test_slug_related_nested_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Invalid value.'
def test_nested_representation(self):
representation = self.nested_field.to_representation(self.instance)
assert representation == self.instance.nested.name
def test_nested_overriding_get_queryset(self):
qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
return qs
field = NoQuerySetSlugRelatedField(slug_field='nested__name')
field.to_internal_value(self.instance.nested.name)
# testing non-nested relations
def test_slug_related_lookup_exists(self):
instance = self.field.to_internal_value(self.instance.name)
assert instance is self.instance
def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0]
assert msg == 'Object with name=doesnotexist does not exist.'
def test_slug_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Invalid value.'
def test_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == self.instance.name
def test_overriding_get_queryset(self):
qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
return qs
field = NoQuerySetSlugRelatedField(slug_field='name')
field.to_internal_value(self.instance.name)
class TestManyRelatedField(APISimpleTestCase):
def setUp(self):
self.instance = MockObject(pk=1, name='foo')

View File

@ -1,5 +1,4 @@
import re
from collections import OrderedDict
from collections.abc import MutableMapping
import pytest
@ -457,12 +456,12 @@ class CacheRenderTest(TestCase):
class TestJSONIndentationStyles:
def test_indented(self):
renderer = JSONRenderer()
data = OrderedDict([('a', 1), ('b', 2)])
data = {"a": 1, "b": 2}
assert renderer.render(data) == b'{"a":1,"b":2}'
def test_compact(self):
renderer = JSONRenderer()
data = OrderedDict([('a', 1), ('b', 2)])
data = {"a": 1, "b": 2}
context = {'indent': 4}
assert (
renderer.render(data, renderer_context=context) ==
@ -472,7 +471,7 @@ class TestJSONIndentationStyles:
def test_long_form(self):
renderer = JSONRenderer()
renderer.compact = False
data = OrderedDict([('a', 1), ('b', 2)])
data = {"a": 1, "b": 2}
assert renderer.render(data) == b'{"a": 1, "b": 2}'
@ -634,6 +633,9 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
class AuthExampleViewSet(ExampleViewSet):
permission_classes = [permissions.IsAuthenticated]
class SimpleSerializer(serializers.Serializer):
name = serializers.CharField()
router = SimpleRouter()
router.register('examples', ExampleViewSet, basename='example')
router.register('auth-examples', AuthExampleViewSet, basename='auth-example')
@ -641,6 +643,62 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
def setUp(self):
self.renderer = BrowsableAPIRenderer()
self.renderer.accepted_media_type = ''
self.renderer.renderer_context = {}
def test_render_form_for_serializer(self):
with self.subTest('Serializer'):
serializer = BrowsableAPIRendererTests.SimpleSerializer(data={'name': 'Name'})
form = self.renderer.render_form_for_serializer(serializer)
assert isinstance(form, str), 'Must return form for serializer'
with self.subTest('ListSerializer'):
list_serializer = BrowsableAPIRendererTests.SimpleSerializer(data=[{'name': 'Name'}], many=True)
form = self.renderer.render_form_for_serializer(list_serializer)
assert form is None, 'Must not return form for list serializer'
def test_get_raw_data_form(self):
with self.subTest('Serializer'):
class DummyGenericViewsetLike(APIView):
def get_serializer(self, **kwargs):
return BrowsableAPIRendererTests.SimpleSerializer(**kwargs)
def get(self, request):
response = Response()
response.view = self
return response
post = get
view = DummyGenericViewsetLike.as_view()
_request = APIRequestFactory().get('/')
request = Request(_request)
response = view(_request)
view = response.view
raw_data_form = self.renderer.get_raw_data_form({'name': 'Name'}, view, 'POST', request)
assert raw_data_form['_content'].initial == '{\n "name": ""\n}'
with self.subTest('ListSerializer'):
class DummyGenericViewsetLike(APIView):
def get_serializer(self, **kwargs):
return BrowsableAPIRendererTests.SimpleSerializer(many=True, **kwargs) # returns ListSerializer
def get(self, request):
response = Response()
response.view = self
return response
post = get
view = DummyGenericViewsetLike.as_view()
_request = APIRequestFactory().get('/')
request = Request(_request)
response = view(_request)
view = response.view
raw_data_form = self.renderer.get_raw_data_form([{'name': 'Name'}], view, 'POST', request)
assert raw_data_form['_content'].initial == '[\n {\n "name": ""\n }\n]'
def test_get_description_returns_empty_string_for_401_and_403_statuses(self):
assert self.renderer.get_description({}, status_code=401) == ''

View File

@ -2,6 +2,7 @@ import inspect
import pickle
import re
import sys
import unittest
from collections import ChainMap
from collections.abc import Mapping
@ -764,3 +765,84 @@ class Test8301Regression:
assert (s.data | {}).__class__ == s.data.__class__
assert ({} | s.data).__class__ == s.data.__class__
class TestSetValueMethod:
# Serializer.set_value() modifies the first parameter in-place.
s = serializers.Serializer()
def test_no_keys(self):
ret = {'a': 1}
self.s.set_value(ret, [], {'b': 2})
assert ret == {'a': 1, 'b': 2}
def test_one_key(self):
ret = {'a': 1}
self.s.set_value(ret, ['x'], 2)
assert ret == {'a': 1, 'x': 2}
def test_nested_key(self):
ret = {'a': 1}
self.s.set_value(ret, ['x', 'y'], 2)
assert ret == {'a': 1, 'x': {'y': 2}}
class MyClass(models.Model):
name = models.CharField(max_length=100)
value = models.CharField(max_length=100, blank=True)
app_label = "test"
@property
def is_valid(self):
return self.name == 'valid'
class MyClassSerializer(serializers.ModelSerializer):
class Meta:
model = MyClass
fields = ('id', 'name', 'value')
def validate_value(self, value):
if value and not self.instance.is_valid:
raise serializers.ValidationError(
'Status cannot be set for invalid instance')
return value
class TestMultipleObjectsValidation(unittest.TestCase):
def setUp(self):
self.objs = [
MyClass(name='valid'),
MyClass(name='invalid'),
MyClass(name='other'),
]
def test_multiple_objects_are_validated_separately(self):
serializer = MyClassSerializer(
data=[{'value': 'set', 'id': instance.id} for instance in
self.objs],
instance=self.objs,
many=True,
partial=True,
)
assert not serializer.is_valid()
assert serializer.errors == [
{},
{'value': ['Status cannot be set for invalid instance']},
{'value': ['Status cannot be set for invalid instance']}
]
def test_exception_raised_when_data_and_instance_length_different(self):
with self.assertRaises(AssertionError):
MyClassSerializer(
data=[{'value': 'set', 'id': instance.id} for instance in
self.objs],
instance=self.objs[:-1],
many=True,
partial=True,
)

View File

@ -1,4 +1,5 @@
import datetime
from unittest.mock import MagicMock
import pytest
from django.db import DataError, models
@ -787,3 +788,13 @@ class ValidatorsTests(TestCase):
validator.filter_queryset(
attrs=None, queryset=None, field_name='', date_field_name=''
)
def test_equality_operator(self):
mock_queryset = MagicMock()
validator = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
validator2 = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
assert validator == validator2
validator2.date_field = "bar2"
assert validator != validator2

View File

@ -152,6 +152,8 @@ class TestURLReversing(URLPatternsTestCase, APITestCase):
path('v1/', include((included, 'v1'), namespace='v1')),
path('another/', dummy_view, name='another'),
re_path(r'^(?P<version>[v1|v2]+)/another/$', dummy_view, name='another'),
re_path(r'^(?P<foo>.+)/unversioned/$', dummy_view, name='unversioned'),
]
def test_reverse_unversioned(self):
@ -198,6 +200,14 @@ class TestURLReversing(URLPatternsTestCase, APITestCase):
response = view(request)
assert response.data == {'url': 'http://testserver/another/'}
# Test fallback when kwargs is not None
request = factory.get('/v1/endpoint/')
request.versioning_scheme = scheme()
request.version = 'v1'
reversed_url = reverse('unversioned', request=request, kwargs={'foo': 'bar'})
assert reversed_url == 'http://testserver/bar/unversioned/'
def test_reverse_namespace_versioning(self):
class FakeResolverMatch(ResolverMatch):
namespace = 'v1'
@ -262,7 +272,7 @@ class TestInvalidVersion:
assert response.status_code == status.HTTP_404_NOT_FOUND
class TestAllowedAndDefaultVersion:
class TestAcceptHeaderAllowedAndDefaultVersion:
def test_missing_without_default(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
@ -308,6 +318,97 @@ class TestAllowedAndDefaultVersion:
assert response.data == {'version': 'v2'}
class TestNamespaceAllowedAndDefaultVersion:
def test_no_namespace_without_default(self):
class FakeResolverMatch:
namespace = None
scheme = versioning.NamespaceVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_no_namespace_with_default(self):
class FakeResolverMatch:
namespace = None
scheme = versioning.NamespaceVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
def test_no_match_without_default(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_no_match_with_default(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
def test_with_default(self):
class FakeResolverMatch:
namespace = 'v1'
scheme = versioning.NamespaceVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v1'}
def test_no_match_without_default_but_none_allowed(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedWithNoneVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': None}
def test_no_match_with_default_and_none_allowed(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedWithNoneAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
class TestHyperlinkedRelatedField(URLPatternsTestCase, APITestCase):
included = [
path('namespaced/<int:pk>/', dummy_pk_view, name='namespaced'),

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
from functools import wraps
import pytest
@ -261,11 +260,11 @@ class GetExtraActionUrlMapTests(TestCase):
response = self.client.get('/api/actions/')
view = response.view
expected = OrderedDict([
('Custom list action', 'http://testserver/api/actions/custom_list_action/'),
('List action', 'http://testserver/api/actions/list_action/'),
('Wrapped list action', 'http://testserver/api/actions/wrapped_list_action/'),
])
expected = {
'Custom list action': 'http://testserver/api/actions/custom_list_action/',
'List action': 'http://testserver/api/actions/list_action/',
'Wrapped list action': 'http://testserver/api/actions/wrapped_list_action/',
}
self.assertEqual(view.get_extra_action_url_map(), expected)
@ -273,28 +272,28 @@ class GetExtraActionUrlMapTests(TestCase):
response = self.client.get('/api/actions/1/')
view = response.view
expected = OrderedDict([
('Custom detail action', 'http://testserver/api/actions/1/custom_detail_action/'),
('Detail action', 'http://testserver/api/actions/1/detail_action/'),
('Wrapped detail action', 'http://testserver/api/actions/1/wrapped_detail_action/'),
expected = {
'Custom detail action': 'http://testserver/api/actions/1/custom_detail_action/',
'Detail action': 'http://testserver/api/actions/1/detail_action/',
'Wrapped detail action': 'http://testserver/api/actions/1/wrapped_detail_action/',
# "Unresolvable detail action" excluded, since it's not resolvable
])
}
self.assertEqual(view.get_extra_action_url_map(), expected)
def test_uninitialized_view(self):
self.assertEqual(ActionViewSet().get_extra_action_url_map(), OrderedDict())
self.assertEqual(ActionViewSet().get_extra_action_url_map(), {})
def test_action_names(self):
# Action 'name' and 'suffix' kwargs should be respected
response = self.client.get('/api/names/1/')
view = response.view
expected = OrderedDict([
('Custom Name', 'http://testserver/api/names/1/named_action/'),
('Action Names Custom Suffix', 'http://testserver/api/names/1/suffixed_action/'),
('Unnamed action', 'http://testserver/api/names/1/unnamed_action/'),
])
expected = {
'Custom Name': 'http://testserver/api/names/1/named_action/',
'Action Names Custom Suffix': 'http://testserver/api/names/1/suffixed_action/',
'Unnamed action': 'http://testserver/api/names/1/unnamed_action/',
}
self.assertEqual(view.get_extra_action_url_map(), expected)

View File

@ -1,3 +1,5 @@
from operator import attrgetter
from django.core.exceptions import ObjectDoesNotExist
from django.urls import NoReverseMatch
@ -26,7 +28,7 @@ class MockQueryset:
def get(self, **lookup):
for item in self.items:
if all([
getattr(item, key, None) == value
attrgetter(key.replace('__', '.'))(item) == value
for key, value in lookup.items()
]):
return item
@ -39,6 +41,7 @@ class BadType:
will raise a `TypeError`, as occurs in Django when making
queryset lookups with an incorrect type for the lookup value.
"""
def __eq__(self):
raise TypeError()

View File

@ -21,7 +21,7 @@ deps =
django32: Django>=3.2,<4.0
django40: Django>=4.0,<4.1
django41: Django>=4.1,<4.2
django42: Django>=4.2b1,<5.0
django42: Django>=4.2,<5.0
djangomain: https://github.com/django/django/archive/main.tar.gz
-rrequirements/requirements-testing.txt
-rrequirements/requirements-optionals.txt