Merge branch 'master' into contextvars-for-request-context

This commit is contained in:
Daler 2023-06-21 14:48:41 +05:00 committed by GitHub
commit 6a7df3bb09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
53 changed files with 1204 additions and 262 deletions

View File

@ -8,17 +8,17 @@ on:
jobs: jobs:
pre-commit: pre-commit:
runs-on: ubuntu-20.04 runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v3
with: with:
fetch-depth: 0 fetch-depth: 0
- uses: actions/setup-python@v2 - uses: actions/setup-python@v4
with: with:
python-version: 3.9 python-version: "3.10"
- uses: pre-commit/action@v2.0.0 - uses: pre-commit/action@v3.0.0
with: with:
token: ${{ secrets.GITHUB_TOKEN }} token: ${{ secrets.GITHUB_TOKEN }}

View File

@ -55,7 +55,7 @@ There is a live example API for testing purposes, [available here][sandbox].
# Requirements # Requirements
* Python 3.6+ * Python 3.6+
* Django 4.1, 4.0, 3.2, 3.1, 3.0 * Django 4.2, 4.1, 4.0, 3.2, 3.1, 3.0
We **highly recommend** and only officially support the latest patch release of We **highly recommend** and only officially support the latest patch release of
each Python and Django series. each Python and Django series.

View File

@ -165,7 +165,7 @@ This permission is suitable if you want your API to only be accessible to a subs
## IsAuthenticatedOrReadOnly ## IsAuthenticatedOrReadOnly
The `IsAuthenticatedOrReadOnly` will allow authenticated users to perform any request. Requests for unauthorised users will only be permitted if the request method is one of the "safe" methods; `GET`, `HEAD` or `OPTIONS`. The `IsAuthenticatedOrReadOnly` will allow authenticated users to perform any request. Requests for unauthenticated users will only be permitted if the request method is one of the "safe" methods; `GET`, `HEAD` or `OPTIONS`.
This permission is suitable if you want to your API to allow read permissions to anonymous users, and only allow write permissions to authenticated users. This permission is suitable if you want to your API to allow read permissions to anonymous users, and only allow write permissions to authenticated users.

View File

@ -54,5 +54,5 @@ As with the `reverse` function, you should **include the request as a keyword ar
api_root = reverse_lazy('api-root', request=request) api_root = reverse_lazy('api-root', request=request)
[cite]: https://www.ics.uci.edu/~fielding/pubs/dissertation/rest_arch_style.htm#sec_5_1_5 [cite]: https://www.ics.uci.edu/~fielding/pubs/dissertation/rest_arch_style.htm#sec_5_1_5
[reverse]: https://docs.djangoproject.com/en/stable/topics/http/urls/#reverse [reverse]: https://docs.djangoproject.com/en/stable/ref/urlresolvers/#reverse
[reverse-lazy]: https://docs.djangoproject.com/en/stable/topics/http/urls/#reverse-lazy [reverse-lazy]: https://docs.djangoproject.com/en/stable/ref/urlresolvers/#reverse-lazy

View File

@ -53,7 +53,7 @@ If we open up the Django shell using `manage.py shell` we can now
The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field. The interesting bit here is the `reference` field. We can see that the uniqueness constraint is being explicitly enforced by a validator on the serializer field.
Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. Because of this more explicit style REST framework includes a few validator classes that are not available in core Django. These classes are detailed below. REST framework validators, like their Django counterparts, implement the `__eq__` method, allowing you to compare instances for equality.
--- ---
@ -295,13 +295,14 @@ To write a class-based validator, use the `__call__` method. Class-based validat
In some advanced cases you might want a validator to be passed the serializer In some advanced cases you might want a validator to be passed the serializer
field it is being used with as additional context. You can do so by setting field it is being used with as additional context. You can do so by setting
a `requires_context = True` attribute on the validator. The `__call__` method a `requires_context = True` attribute on the validator class. The `__call__` method
will then be called with the `serializer_field` will then be called with the `serializer_field`
or `serializer` as an additional argument. or `serializer` as an additional argument.
requires_context = True class MultipleOf:
requires_context = True
def __call__(self, value, serializer_field): def __call__(self, value, serializer_field):
... ...
[cite]: https://docs.djangoproject.com/en/stable/ref/validators/ [cite]: https://docs.djangoproject.com/en/stable/ref/validators/

View File

@ -143,6 +143,16 @@ class PublisherSearchView(generics.ListAPIView):
--- ---
## Deprecations
### `serializers.NullBooleanField`
`serializers.NullBooleanField` is now pending deprecation, and will be removed in 3.14.
Instead use `serializers.BooleanField` field and set `allow_null=True` which does the same thing.
---
## Funding ## Funding
REST framework is a *collaboratively funded project*. If you use REST framework is a *collaboratively funded project*. If you use

View File

@ -60,3 +60,13 @@ See Pull Request [#7522](https://github.com/encode/django-rest-framework/pull/75
## Minor fixes and improvements ## Minor fixes and improvements
There are a number of minor fixes and improvements in this release. See the [release notes](release-notes.md) page for a complete listing. There are a number of minor fixes and improvements in this release. See the [release notes](release-notes.md) page for a complete listing.
---
## Deprecations
### `serializers.NullBooleanField`
`serializers.NullBooleanField` was moved to pending deprecation in 3.12, and deprecated in 3.13. It has now been removed from the core framework.
Instead use `serializers.BooleanField` field and set `allow_null=True` which does the same thing.

View File

@ -157,6 +157,7 @@ Date: 28th September 2020
* Fix `PrimaryKeyRelatedField` and `HyperlinkedRelatedField` when source field is actually a property. [#7142] * Fix `PrimaryKeyRelatedField` and `HyperlinkedRelatedField` when source field is actually a property. [#7142]
* `Token.generate_key` is now a class method. [#7502] * `Token.generate_key` is now a class method. [#7502]
* `@action` warns if method is wrapped in a decorator that does not preserve information using `@functools.wraps`. [#7098] * `@action` warns if method is wrapped in a decorator that does not preserve information using `@functools.wraps`. [#7098]
* Deprecate `serializers.NullBooleanField` in favour of `serializers.BooleanField` with `allow_null=True` [#7122]
--- ---

View File

@ -86,7 +86,7 @@ continued development by **[signing up for a paid plan][funding]**.
REST framework requires the following: REST framework requires the following:
* Python (3.6, 3.7, 3.8, 3.9, 3.10, 3.11) * Python (3.6, 3.7, 3.8, 3.9, 3.10, 3.11)
* Django (2.2, 3.0, 3.1, 3.2, 4.0, 4.1) * Django (3.0, 3.1, 3.2, 4.0, 4.1)
We **highly recommend** and only officially support the latest patch release of We **highly recommend** and only officially support the latest patch release of
each Python and Django series. each Python and Django series.

View File

@ -1,8 +1,8 @@
# Wheel for PyPI installs. # Wheel for PyPI installs.
wheel>=0.35.1,<0.36 wheel>=0.36.2,<0.40.0
# Twine for secured PyPI uploads. # Twine for secured PyPI uploads.
twine>=3.2.0,<3.3 twine>=3.4.2,<4.0.2
# Transifex client for managing translation resources. # Transifex client for managing translation resources.
transifex-client transifex-client

View File

@ -13,7 +13,7 @@ __title__ = 'Django REST framework'
__version__ = '3.14.0' __version__ = '3.14.0'
__author__ = 'Tom Christie' __author__ = 'Tom Christie'
__license__ = 'BSD 3-Clause' __license__ = 'BSD 3-Clause'
__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd' __copyright__ = 'Copyright 2011-2023 Encode OSS Ltd'
# Version synonym # Version synonym
VERSION = __version__ VERSION = __version__
@ -31,3 +31,7 @@ if django.VERSION < (3, 2):
class RemovedInDRF315Warning(DeprecationWarning): class RemovedInDRF315Warning(DeprecationWarning):
pass pass
class RemovedInDRF317Warning(PendingDeprecationWarning):
pass

View File

@ -4,6 +4,7 @@ from django.contrib.admin.views.main import ChangeList
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.urls import reverse from django.urls import reverse
from django.utils.translation import gettext_lazy as _
from rest_framework.authtoken.models import Token, TokenProxy from rest_framework.authtoken.models import Token, TokenProxy
@ -23,6 +24,8 @@ class TokenChangeList(ChangeList):
class TokenAdmin(admin.ModelAdmin): class TokenAdmin(admin.ModelAdmin):
list_display = ('key', 'user', 'created') list_display = ('key', 'user', 'created')
fields = ('user',) fields = ('user',)
search_fields = ('user__username',)
search_help_text = _('Username')
ordering = ('-created',) ordering = ('-created',)
actions = None # Actions not compatible with mapped IDs. actions = None # Actions not compatible with mapped IDs.
autocomplete_fields = ("user",) autocomplete_fields = ("user",)

View File

@ -4,9 +4,9 @@ import datetime
import decimal import decimal
import functools import functools
import inspect import inspect
import logging
import re import re
import uuid import uuid
from collections import OrderedDict
from collections.abc import Mapping from collections.abc import Mapping
from django.conf import settings from django.conf import settings
@ -17,6 +17,7 @@ from django.core.validators import (
MinValueValidator, ProhibitNullCharactersValidator, RegexValidator, MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
URLValidator, ip_address_validators URLValidator, ip_address_validators
) )
from django.db.models import IntegerChoices, TextChoices
from django.forms import FilePathField as DjangoFilePathField from django.forms import FilePathField as DjangoFilePathField
from django.forms import ImageField as DjangoImageField from django.forms import ImageField as DjangoImageField
from django.utils import timezone from django.utils import timezone
@ -28,15 +29,22 @@ from django.utils.encoding import is_protected_type, smart_str
from django.utils.formats import localize_input, sanitize_separators from django.utils.formats import localize_input, sanitize_separators
from django.utils.ipv6 import clean_ipv6_address from django.utils.ipv6 import clean_ipv6_address
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from pytz.exceptions import InvalidTimeError
try:
import pytz
except ImportError:
pytz = None
from rest_framework import ISO_8601 from rest_framework import ISO_8601
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, json, representation from rest_framework.utils import html, humanize_datetime, json, representation
from rest_framework.utils.formatting import lazy_format from rest_framework.utils.formatting import lazy_format
from rest_framework.utils.timezone import valid_datetime
from rest_framework.validators import ProhibitSurrogateCharactersValidator from rest_framework.validators import ProhibitSurrogateCharactersValidator
logger = logging.getLogger("rest_framework.fields")
class empty: class empty:
""" """
@ -109,27 +117,6 @@ def get_attribute(instance, attrs):
return instance return instance
def set_value(dictionary, keys, value):
"""
Similar to Python's built in `dictionary[key] = value`,
but takes a list of nested keys instead of a single key.
set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
"""
if not keys:
dictionary.update(value)
return
for key in keys[:-1]:
if key not in dictionary:
dictionary[key] = {}
dictionary = dictionary[key]
dictionary[keys[-1]] = value
def to_choices_dict(choices): def to_choices_dict(choices):
""" """
Convert choices into key/value dicts. Convert choices into key/value dicts.
@ -142,7 +129,7 @@ def to_choices_dict(choices):
# choices = [1, 2, 3] # choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
# choices = [('Category', ((1, 'First'), (2, 'Second'))), (3, 'Third')] # choices = [('Category', ((1, 'First'), (2, 'Second'))), (3, 'Third')]
ret = OrderedDict() ret = {}
for choice in choices: for choice in choices:
if not isinstance(choice, (list, tuple)): if not isinstance(choice, (list, tuple)):
# single choice # single choice
@ -165,7 +152,7 @@ def flatten_choices_dict(choices):
flatten_choices_dict({1: '1st', 2: '2nd'}) -> {1: '1st', 2: '2nd'} flatten_choices_dict({1: '1st', 2: '2nd'}) -> {1: '1st', 2: '2nd'}
flatten_choices_dict({'Group': {1: '1st', 2: '2nd'}}) -> {1: '1st', 2: '2nd'} flatten_choices_dict({'Group': {1: '1st', 2: '2nd'}}) -> {1: '1st', 2: '2nd'}
""" """
ret = OrderedDict() ret = {}
for key, value in choices.items(): for key, value in choices.items():
if isinstance(value, dict): if isinstance(value, dict):
# grouped choices (category, sub choices) # grouped choices (category, sub choices)
@ -677,22 +664,27 @@ class BooleanField(Field):
default_empty_html = False default_empty_html = False
initial = False initial = False
TRUE_VALUES = { TRUE_VALUES = {
't', 'T', 't',
'y', 'Y', 'yes', 'Yes', 'YES', 'y',
'true', 'True', 'TRUE', 'yes',
'on', 'On', 'ON', 'true',
'1', 1, 'on',
True '1',
1,
True,
} }
FALSE_VALUES = { FALSE_VALUES = {
'f', 'F', 'f',
'n', 'N', 'no', 'No', 'NO', 'n',
'false', 'False', 'FALSE', 'no',
'off', 'Off', 'OFF', 'false',
'0', 0, 0.0, 'off',
False '0',
0,
0.0,
False,
} }
NULL_VALUES = {'null', 'Null', 'NULL', '', None} NULL_VALUES = {'null', '', None}
def __init__(self, **kwargs): def __init__(self, **kwargs):
if kwargs.get('allow_null', False): if kwargs.get('allow_null', False):
@ -700,22 +692,28 @@ class BooleanField(Field):
self.initial = None self.initial = None
super().__init__(**kwargs) super().__init__(**kwargs)
@staticmethod
def _lower_if_str(value):
if isinstance(value, str):
return value.lower()
return value
def to_internal_value(self, data): def to_internal_value(self, data):
with contextlib.suppress(TypeError): with contextlib.suppress(TypeError):
if data in self.TRUE_VALUES: if self._lower_if_str(data) in self.TRUE_VALUES:
return True return True
elif data in self.FALSE_VALUES: elif self._lower_if_str(data) in self.FALSE_VALUES:
return False return False
elif data in self.NULL_VALUES and self.allow_null: elif self._lower_if_str(data) in self.NULL_VALUES and self.allow_null:
return None return None
self.fail('invalid', input=data) self.fail("invalid", input=data)
def to_representation(self, value): def to_representation(self, value):
if value in self.TRUE_VALUES: if self._lower_if_str(value) in self.TRUE_VALUES:
return True return True
elif value in self.FALSE_VALUES: elif self._lower_if_str(value) in self.FALSE_VALUES:
return False return False
if value in self.NULL_VALUES and self.allow_null: if self._lower_if_str(value) in self.NULL_VALUES and self.allow_null:
return None return None
return bool(value) return bool(value)
@ -989,6 +987,11 @@ class DecimalField(Field):
self.max_value = max_value self.max_value = max_value
self.min_value = min_value self.min_value = min_value
if self.max_value is not None and not isinstance(self.max_value, decimal.Decimal):
logger.warning("max_value in DecimalField should be Decimal type.")
if self.min_value is not None and not isinstance(self.min_value, decimal.Decimal):
logger.warning("min_value in DecimalField should be Decimal type.")
if self.max_digits is not None and self.decimal_places is not None: if self.max_digits is not None and self.decimal_places is not None:
self.max_whole_digits = self.max_digits - self.decimal_places self.max_whole_digits = self.max_digits - self.decimal_places
else: else:
@ -1154,9 +1157,16 @@ class DateTimeField(Field):
except OverflowError: except OverflowError:
self.fail('overflow') self.fail('overflow')
try: try:
return timezone.make_aware(value, field_timezone) dt = timezone.make_aware(value, field_timezone)
except InvalidTimeError: # When the resulting datetime is a ZoneInfo instance, it won't necessarily
self.fail('make_aware', timezone=field_timezone) # throw given an invalid datetime, so we need to specifically check.
if not valid_datetime(dt):
self.fail('make_aware', timezone=field_timezone)
return dt
except Exception as e:
if pytz and isinstance(e, pytz.exceptions.InvalidTimeError):
self.fail('make_aware', timezone=field_timezone)
raise e
elif (field_timezone is None) and timezone.is_aware(value): elif (field_timezone is None) and timezone.is_aware(value):
return timezone.make_naive(value, datetime.timezone.utc) return timezone.make_naive(value, datetime.timezone.utc)
return value return value
@ -1392,6 +1402,10 @@ class ChoiceField(Field):
if data == '' and self.allow_blank: if data == '' and self.allow_blank:
return '' return ''
if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \
str(data.value):
data = data.value
try: try:
return self.choice_strings_to_values[str(data)] return self.choice_strings_to_values[str(data)]
except KeyError: except KeyError:
@ -1400,6 +1414,11 @@ class ChoiceField(Field):
def to_representation(self, value): def to_representation(self, value):
if value in ('', None): if value in ('', None):
return value return value
if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \
str(value.value):
value = value.value
return self.choice_strings_to_values.get(str(value), value) return self.choice_strings_to_values.get(str(value), value)
def iter_options(self): def iter_options(self):
@ -1423,7 +1442,8 @@ class ChoiceField(Field):
# Allows us to deal with eg. integer choices while supporting either # Allows us to deal with eg. integer choices while supporting either
# integer or string input, but still get the correct datatype out. # integer or string input, but still get the correct datatype out.
self.choice_strings_to_values = { self.choice_strings_to_values = {
str(key): key for key in self.choices str(key.value) if isinstance(key, (IntegerChoices, TextChoices))
and str(key) != str(key.value) else str(key): key for key in self.choices
} }
choices = property(_get_choices, _set_choices) choices = property(_get_choices, _set_choices)
@ -1643,7 +1663,7 @@ class ListField(Field):
def run_child_validation(self, data): def run_child_validation(self, data):
result = [] result = []
errors = OrderedDict() errors = {}
for idx, item in enumerate(data): for idx, item in enumerate(data):
try: try:
@ -1707,7 +1727,7 @@ class DictField(Field):
def run_child_validation(self, data): def run_child_validation(self, data):
result = {} result = {}
errors = OrderedDict() errors = {}
for key, value in data.items(): for key, value in data.items():
key = str(key) key = str(key)

View File

@ -3,6 +3,7 @@ Provides generic filtering backends that can be used to filter the results
returned by list views. returned by list views.
""" """
import operator import operator
import warnings
from functools import reduce from functools import reduce
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
@ -12,6 +13,7 @@ from django.template import loader
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema, distinct from rest_framework.compat import coreapi, coreschema, distinct
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -29,6 +31,8 @@ class BaseFilterBackend:
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [] return []
@ -146,6 +150,8 @@ class SearchFilter(BaseFilterBackend):
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [ return [
coreapi.Field( coreapi.Field(
@ -306,6 +312,8 @@ class OrderingFilter(BaseFilterBackend):
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [ return [
coreapi.Field( coreapi.Field(

View File

@ -6,8 +6,6 @@ some fairly ad-hoc information about the view.
Future implementations might use JSON schema or other definitions in order Future implementations might use JSON schema or other definitions in order
to return this information in a more standardized way. to return this information in a more standardized way.
""" """
from collections import OrderedDict
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from django.utils.encoding import force_str from django.utils.encoding import force_str
@ -59,11 +57,12 @@ class SimpleMetadata(BaseMetadata):
}) })
def determine_metadata(self, request, view): def determine_metadata(self, request, view):
metadata = OrderedDict() metadata = {
metadata['name'] = view.get_view_name() "name": view.get_view_name(),
metadata['description'] = view.get_view_description() "description": view.get_view_description(),
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes] "renders": [renderer.media_type for renderer in view.renderer_classes],
metadata['parses'] = [parser.media_type for parser in view.parser_classes] "parses": [parser.media_type for parser in view.parser_classes],
}
if hasattr(view, 'get_serializer'): if hasattr(view, 'get_serializer'):
actions = self.determine_actions(request, view) actions = self.determine_actions(request, view)
if actions: if actions:
@ -106,25 +105,27 @@ class SimpleMetadata(BaseMetadata):
# If this is a `ListSerializer` then we want to examine the # If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead. # underlying child serializer instance instead.
serializer = serializer.child serializer = serializer.child
return OrderedDict([ return {
(field_name, self.get_field_info(field)) field_name: self.get_field_info(field)
for field_name, field in serializer.fields.items() for field_name, field in serializer.fields.items()
if not isinstance(field, serializers.HiddenField) if not isinstance(field, serializers.HiddenField)
]) }
def get_field_info(self, field): def get_field_info(self, field):
""" """
Given an instance of a serializer field, return a dictionary Given an instance of a serializer field, return a dictionary
of metadata about it. of metadata about it.
""" """
field_info = OrderedDict() field_info = {
field_info['type'] = self.label_lookup[field] "type": self.label_lookup[field],
field_info['required'] = getattr(field, 'required', False) "required": getattr(field, "required", False),
}
attrs = [ attrs = [
'read_only', 'label', 'help_text', 'read_only', 'label', 'help_text',
'min_length', 'max_length', 'min_length', 'max_length',
'min_value', 'max_value' 'min_value', 'max_value',
'max_digits', 'decimal_places'
] ]
for attr in attrs: for attr in attrs:

View File

@ -4,16 +4,19 @@ be used for paginated responses.
""" """
import contextlib import contextlib
import warnings
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple from collections import namedtuple
from urllib import parse from urllib import parse
from django.core.paginator import InvalidPage from django.core.paginator import InvalidPage
from django.core.paginator import Paginator as DjangoPaginator from django.core.paginator import Paginator as DjangoPaginator
from django.db.models import Q
from django.template import loader from django.template import loader
from django.utils.encoding import force_str from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import RemovedInDRF317Warning
from rest_framework.compat import coreapi, coreschema from rest_framework.compat import coreapi, coreschema
from rest_framework.exceptions import NotFound from rest_framework.exceptions import NotFound
from rest_framework.response import Response from rest_framework.response import Response
@ -151,6 +154,8 @@ class BasePagination:
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
return [] return []
def get_schema_operation_parameters(self, view): def get_schema_operation_parameters(self, view):
@ -224,12 +229,12 @@ class PageNumberPagination(BasePagination):
return page_number return page_number
def get_paginated_response(self, data): def get_paginated_response(self, data):
return Response(OrderedDict([ return Response({
('count', self.page.paginator.count), 'count': self.page.paginator.count,
('next', self.get_next_link()), 'next': self.get_next_link(),
('previous', self.get_previous_link()), 'previous': self.get_previous_link(),
('results', data) 'results': data,
])) })
def get_paginated_response_schema(self, schema): def get_paginated_response_schema(self, schema):
return { return {
@ -310,6 +315,8 @@ class PageNumberPagination(BasePagination):
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [ fields = [
coreapi.Field( coreapi.Field(
@ -394,12 +401,12 @@ class LimitOffsetPagination(BasePagination):
return list(queryset[self.offset:self.offset + self.limit]) return list(queryset[self.offset:self.offset + self.limit])
def get_paginated_response(self, data): def get_paginated_response(self, data):
return Response(OrderedDict([ return Response({
('count', self.count), 'count': self.count,
('next', self.get_next_link()), 'next': self.get_next_link(),
('previous', self.get_previous_link()), 'previous': self.get_previous_link(),
('results', data) 'results': data
])) })
def get_paginated_response_schema(self, schema): def get_paginated_response_schema(self, schema):
return { return {
@ -524,6 +531,8 @@ class LimitOffsetPagination(BasePagination):
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
return [ return [
coreapi.Field( coreapi.Field(
@ -620,7 +629,7 @@ class CursorPagination(BasePagination):
queryset = queryset.order_by(*self.ordering) queryset = queryset.order_by(*self.ordering)
# If we have a cursor with a fixed position then filter by that. # If we have a cursor with a fixed position then filter by that.
if current_position is not None: if str(current_position) != 'None':
order = self.ordering[0] order = self.ordering[0]
is_reversed = order.startswith('-') is_reversed = order.startswith('-')
order_attr = order.lstrip('-') order_attr = order.lstrip('-')
@ -631,7 +640,12 @@ class CursorPagination(BasePagination):
else: else:
kwargs = {order_attr + '__gt': current_position} kwargs = {order_attr + '__gt': current_position}
queryset = queryset.filter(**kwargs) filter_query = Q(**kwargs)
# If some records contain a null for the ordering field, don't lose them.
# When reverse ordering, nulls will come last and need to be included.
if (reverse and not is_reversed) or is_reversed:
filter_query |= Q(**{order_attr + '__isnull': True})
queryset = queryset.filter(filter_query)
# If we have an offset cursor then offset the entire page by that amount. # If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a # We also always fetch an extra item in order to determine if there is a
@ -704,7 +718,7 @@ class CursorPagination(BasePagination):
# The item in this position and the item following it # The item in this position and the item following it
# have different positions. We can use this position as # have different positions. We can use this position as
# our marker. # our marker.
has_item_with_unique_position = True has_item_with_unique_position = position is not None
break break
# The item in this position has the same position as the item # The item in this position has the same position as the item
@ -757,7 +771,7 @@ class CursorPagination(BasePagination):
# The item in this position and the item following it # The item in this position and the item following it
# have different positions. We can use this position as # have different positions. We can use this position as
# our marker. # our marker.
has_item_with_unique_position = True has_item_with_unique_position = position is not None
break break
# The item in this position has the same position as the item # The item in this position has the same position as the item
@ -795,6 +809,10 @@ class CursorPagination(BasePagination):
""" """
Return a tuple of strings, that may be used in an `order_by` method. Return a tuple of strings, that may be used in an `order_by` method.
""" """
# The default case is to check for an `ordering` attribute
# on this pagination instance.
ordering = self.ordering
ordering_filters = [ ordering_filters = [
filter_cls for filter_cls in getattr(view, 'filter_backends', []) filter_cls for filter_cls in getattr(view, 'filter_backends', [])
if hasattr(filter_cls, 'get_ordering') if hasattr(filter_cls, 'get_ordering')
@ -805,26 +823,19 @@ class CursorPagination(BasePagination):
# then we defer to that filter to determine the ordering. # then we defer to that filter to determine the ordering.
filter_cls = ordering_filters[0] filter_cls = ordering_filters[0]
filter_instance = filter_cls() filter_instance = filter_cls()
ordering = filter_instance.get_ordering(request, queryset, view) ordering_from_filter = filter_instance.get_ordering(request, queryset, view)
assert ordering is not None, ( if ordering_from_filter:
'Using cursor pagination, but filter class {filter_cls} ' ordering = ordering_from_filter
'returned a `None` ordering.'.format(
filter_cls=filter_cls.__name__ assert ordering is not None, (
) 'Using cursor pagination, but no ordering attribute was declared '
) 'on the pagination class.'
else: )
# The default case is to check for an `ordering` attribute assert '__' not in ordering, (
# on this pagination instance. 'Cursor pagination does not support double underscore lookups '
ordering = self.ordering 'for orderings. Orderings should be an unchanging, unique or '
assert ordering is not None, ( 'nearly-unique field on the model, such as "-created" or "pk".'
'Using cursor pagination, but no ordering attribute was declared ' )
'on the pagination class.'
)
assert '__' not in ordering, (
'Cursor pagination does not support double underscore lookups '
'for orderings. Orderings should be an unchanging, unique or '
'nearly-unique field on the model, such as "-created" or "pk".'
)
assert isinstance(ordering, (str, list, tuple)), ( assert isinstance(ordering, (str, list, tuple)), (
'Invalid ordering. Expected string or tuple, but got {type}'.format( 'Invalid ordering. Expected string or tuple, but got {type}'.format(
@ -883,14 +894,14 @@ class CursorPagination(BasePagination):
attr = instance[field_name] attr = instance[field_name]
else: else:
attr = getattr(instance, field_name) attr = getattr(instance, field_name)
return str(attr) return None if attr is None else str(attr)
def get_paginated_response(self, data): def get_paginated_response(self, data):
return Response(OrderedDict([ return Response({
('next', self.get_next_link()), 'next': self.get_next_link(),
('previous', self.get_previous_link()), 'previous': self.get_previous_link(),
('results', data) 'results': data,
])) })
def get_paginated_response_schema(self, schema): def get_paginated_response_schema(self, schema):
return { return {
@ -927,6 +938,8 @@ class CursorPagination(BasePagination):
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`' assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
fields = [ fields = [
coreapi.Field( coreapi.Field(

View File

@ -1,6 +1,6 @@
import contextlib import contextlib
import sys import sys
from collections import OrderedDict from operator import attrgetter
from urllib import parse from urllib import parse
from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
@ -71,6 +71,7 @@ class PKOnlyObject:
instance, but still want to return an object with a .pk attribute, instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance. in order to keep the same interface as a regular model instance.
""" """
def __init__(self, pk): def __init__(self, pk):
self.pk = pk self.pk = pk
@ -197,13 +198,9 @@ class RelatedField(Field):
if cutoff is not None: if cutoff is not None:
queryset = queryset[:cutoff] queryset = queryset[:cutoff]
return OrderedDict([ return {
( self.to_representation(item): self.display_value(item) for item in queryset
self.to_representation(item), }
self.display_value(item)
)
for item in queryset
])
@property @property
def choices(self): def choices(self):
@ -464,7 +461,11 @@ class SlugRelatedField(RelatedField):
self.fail('invalid') self.fail('invalid')
def to_representation(self, obj): def to_representation(self, obj):
return getattr(obj, self.slug_field) slug = self.slug_field
if "__" in slug:
# handling nested relationship if defined
slug = slug.replace('__', '.')
return attrgetter(slug)(obj)
class ManyRelatedField(Field): class ManyRelatedField(Field):

View File

@ -9,7 +9,7 @@ REST framework also provides an HTML renderer that renders the browsable API.
import base64 import base64
import contextlib import contextlib
from collections import OrderedDict import datetime
from urllib import parse from urllib import parse
from django import forms from django import forms
@ -507,6 +507,9 @@ class BrowsableAPIRenderer(BaseRenderer):
return self.render_form_for_serializer(serializer) return self.render_form_for_serializer(serializer)
def render_form_for_serializer(self, serializer): def render_form_for_serializer(self, serializer):
if isinstance(serializer, serializers.ListSerializer):
return None
if hasattr(serializer, 'initial_data'): if hasattr(serializer, 'initial_data'):
serializer.is_valid() serializer.is_valid()
@ -556,10 +559,13 @@ class BrowsableAPIRenderer(BaseRenderer):
context['indent'] = 4 context['indent'] = 4
# strip HiddenField from output # strip HiddenField from output
is_list_serializer = isinstance(serializer, serializers.ListSerializer)
serializer = serializer.child if is_list_serializer else serializer
data = serializer.data.copy() data = serializer.data.copy()
for name, field in serializer.fields.items(): for name, field in serializer.fields.items():
if isinstance(field, serializers.HiddenField): if isinstance(field, serializers.HiddenField):
data.pop(name, None) data.pop(name, None)
data = [data] if is_list_serializer else data
content = renderer.render(data, accepted, context) content = renderer.render(data, accepted, context)
# Renders returns bytes, but CharField expects a str. # Renders returns bytes, but CharField expects a str.
content = content.decode() content = content.decode()
@ -653,7 +659,7 @@ class BrowsableAPIRenderer(BaseRenderer):
raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request) raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
response_headers = OrderedDict(sorted(response.items())) response_headers = dict(sorted(response.items()))
renderer_content_type = '' renderer_content_type = ''
if renderer: if renderer:
renderer_content_type = '%s' % renderer.media_type renderer_content_type = '%s' % renderer.media_type
@ -1057,6 +1063,7 @@ class OpenAPIRenderer(BaseRenderer):
def ignore_aliases(self, data): def ignore_aliases(self, data):
return True return True
Dumper.add_representer(SafeString, Dumper.represent_str) Dumper.add_representer(SafeString, Dumper.represent_str)
Dumper.add_representer(datetime.timedelta, encoders.CustomScalar.represent_timedelta)
return yaml.dump(data, default_flow_style=False, sort_keys=False, Dumper=Dumper).encode('utf-8') return yaml.dump(data, default_flow_style=False, sort_keys=False, Dumper=Dumper).encode('utf-8')

View File

@ -14,7 +14,7 @@ For example, you might have a `urls.py` that looks something like this:
urlpatterns = router.urls urlpatterns = router.urls
""" """
import itertools import itertools
from collections import OrderedDict, namedtuple from collections import namedtuple
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.urls import NoReverseMatch, path, re_path from django.urls import NoReverseMatch, path, re_path
@ -321,7 +321,7 @@ class APIRootView(views.APIView):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
# Return a plain {"name": "hyperlink"} response. # Return a plain {"name": "hyperlink"} response.
ret = OrderedDict() ret = {}
namespace = request.resolver_match.namespace namespace = request.resolver_match.namespace
for key, url_name in self.api_root_dict.items(): for key, url_name in self.api_root_dict.items():
if namespace: if namespace:
@ -365,7 +365,7 @@ class DefaultRouter(SimpleRouter):
""" """
Return a basic root view. Return a basic root view.
""" """
api_root_dict = OrderedDict() api_root_dict = {}
list_name = self.routes[0].name list_name = self.routes[0].name
for prefix, viewset, basename in self.registry: for prefix, viewset, basename in self.registry:
api_root_dict[prefix] = list_name.format(basename=basename) api_root_dict[prefix] = list_name.format(basename=basename)

View File

@ -1,11 +1,11 @@
import warnings import warnings
from collections import Counter, OrderedDict from collections import Counter
from urllib import parse from urllib import parse
from django.db import models from django.db import models
from django.utils.encoding import force_str from django.utils.encoding import force_str
from rest_framework import exceptions, serializers from rest_framework import RemovedInDRF317Warning, exceptions, serializers
from rest_framework.compat import coreapi, coreschema, uritemplate from rest_framework.compat import coreapi, coreschema, uritemplate
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -54,7 +54,7 @@ to customise schema structure.
""" """
class LinkNode(OrderedDict): class LinkNode(dict):
def __init__(self): def __init__(self):
self.links = [] self.links = []
self.methods_counter = Counter() self.methods_counter = Counter()
@ -118,6 +118,8 @@ class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None): def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
assert coreapi, '`coreapi` must be installed for schema support.' assert coreapi, '`coreapi` must be installed for schema support.'
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert coreschema, '`coreschema` must be installed for schema support.' assert coreschema, '`coreschema` must be installed for schema support.'
super().__init__(title, url, description, patterns, urlconf) super().__init__(title, url, description, patterns, urlconf)
@ -268,11 +270,11 @@ def field_to_schema(field):
) )
elif isinstance(field, serializers.Serializer): elif isinstance(field, serializers.Serializer):
return coreschema.Object( return coreschema.Object(
properties=OrderedDict([ properties={
(key, field_to_schema(value)) key: field_to_schema(value)
for key, value for key, value
in field.fields.items() in field.fields.items()
]), },
title=title, title=title,
description=description description=description
) )
@ -351,6 +353,9 @@ class AutoSchema(ViewInspector):
will be added to auto-generated fields, overwriting on `Field.name` will be added to auto-generated fields, overwriting on `Field.name`
""" """
super().__init__() super().__init__()
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
if manual_fields is None: if manual_fields is None:
manual_fields = [] manual_fields = []
self._manual_fields = manual_fields self._manual_fields = manual_fields
@ -549,7 +554,7 @@ class AutoSchema(ViewInspector):
if not update_with: if not update_with:
return fields return fields
by_name = OrderedDict((f.name, f) for f in fields) by_name = {f.name: f for f in fields}
for f in update_with: for f in update_with:
by_name[f.name] = f by_name[f.name] = f
fields = list(by_name.values()) fields = list(by_name.values())
@ -592,6 +597,9 @@ class ManualSchema(ViewInspector):
* `description`: String description for view. Optional. * `description`: String description for view. Optional.
""" """
super().__init__() super().__init__()
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
self._fields = fields self._fields = fields
self._description = description self._description = description
@ -613,4 +621,6 @@ class ManualSchema(ViewInspector):
def is_enabled(): def is_enabled():
"""Is CoreAPI Mode enabled?""" """Is CoreAPI Mode enabled?"""
if coreapi is not None:
warnings.warn('CoreAPI compatibility is deprecated and will be removed in DRF 3.17', RemovedInDRF317Warning)
return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema) return issubclass(api_settings.DEFAULT_SCHEMA_CLASS, AutoSchema)

View File

@ -1,6 +1,5 @@
import re import re
import warnings import warnings
from collections import OrderedDict
from decimal import Decimal from decimal import Decimal
from operator import attrgetter from operator import attrgetter
from urllib.parse import urljoin from urllib.parse import urljoin
@ -340,7 +339,7 @@ class AutoSchema(ViewInspector):
return paginator.get_schema_operation_parameters(view) return paginator.get_schema_operation_parameters(view)
def map_choicefield(self, field): def map_choicefield(self, field):
choices = list(OrderedDict.fromkeys(field.choices)) # preserve order and remove duplicates choices = list(dict.fromkeys(field.choices)) # preserve order and remove duplicates
if all(isinstance(choice, bool) for choice in choices): if all(isinstance(choice, bool) for choice in choices):
type = 'boolean' type = 'boolean'
elif all(isinstance(choice, int) for choice in choices): elif all(isinstance(choice, int) for choice in choices):

View File

@ -15,7 +15,7 @@ import contextlib
import copy import copy
import inspect import inspect
import traceback import traceback
from collections import OrderedDict, defaultdict from collections import defaultdict
from collections.abc import Mapping from collections.abc import Mapping
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
@ -315,7 +315,7 @@ class SerializerMetaclass(type):
for name, f in base._declared_fields.items() if name not in known for name, f in base._declared_fields.items() if name not in known
] ]
return OrderedDict(base_fields + fields) return dict(base_fields + fields)
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
@ -353,6 +353,26 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
} }
def set_value(self, dictionary, keys, value):
"""
Similar to Python's built in `dictionary[key] = value`,
but takes a list of nested keys instead of a single key.
set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
"""
if not keys:
dictionary.update(value)
return
for key in keys[:-1]:
if key not in dictionary:
dictionary[key] = {}
dictionary = dictionary[key]
dictionary[keys[-1]] = value
@cached_property @cached_property
def fields(self): def fields(self):
""" """
@ -400,20 +420,20 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
if hasattr(self, 'initial_data'): if hasattr(self, 'initial_data'):
# initial_data may not be a valid type # initial_data may not be a valid type
if not isinstance(self.initial_data, Mapping): if not isinstance(self.initial_data, Mapping):
return OrderedDict() return {}
return OrderedDict([ return {
(field_name, field.get_value(self.initial_data)) field_name: field.get_value(self.initial_data)
for field_name, field in self.fields.items() for field_name, field in self.fields.items()
if (field.get_value(self.initial_data) is not empty) and if (field.get_value(self.initial_data) is not empty) and
not field.read_only not field.read_only
]) }
return OrderedDict([ return {
(field.field_name, field.get_initial()) field.field_name: field.get_initial()
for field in self.fields.values() for field in self.fields.values()
if not field.read_only if not field.read_only
]) }
def get_value(self, dictionary): def get_value(self, dictionary):
# We override the default field access in order to support # We override the default field access in order to support
@ -448,7 +468,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source) if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
] ]
defaults = OrderedDict() defaults = {}
for field in fields: for field in fields:
try: try:
default = field.get_default() default = field.get_default()
@ -481,8 +501,8 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
api_settings.NON_FIELD_ERRORS_KEY: [message] api_settings.NON_FIELD_ERRORS_KEY: [message]
}, code='invalid') }, code='invalid')
ret = OrderedDict() ret = {}
errors = OrderedDict() errors = {}
fields = self._writable_fields fields = self._writable_fields
for field in fields: for field in fields:
@ -499,7 +519,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
except SkipField: except SkipField:
pass pass
else: else:
set_value(ret, field.source_attrs, validated_value) self.set_value(ret, field.source_attrs, validated_value)
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
@ -510,7 +530,7 @@ class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
""" """
Object instance -> Dict of primitive datatypes. Object instance -> Dict of primitive datatypes.
""" """
ret = OrderedDict() ret = {}
fields = self._readable_fields fields = self._readable_fields
for field in fields: for field in fields:
@ -596,6 +616,12 @@ class ListSerializer(BaseSerializer):
self.min_length = kwargs.pop('min_length', None) self.min_length = kwargs.pop('min_length', None)
assert self.child is not None, '`child` is a required argument.' assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.' assert not inspect.isclass(self.child), '`child` has not been instantiated.'
instance = kwargs.get('instance', [])
data = kwargs.get('data', [])
if instance and data:
assert len(data) == len(instance), 'Data and instance should have same length'
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.child.bind(field_name='', parent=self) self.child.bind(field_name='', parent=self)
@ -670,7 +696,13 @@ class ListSerializer(BaseSerializer):
ret = [] ret = []
errors = [] errors = []
for item in data: for idx, item in enumerate(data):
if (
hasattr(self, 'instance')
and self.instance
and len(self.instance) > idx
):
self.child.instance = self.instance[idx]
try: try:
validated = self.child.run_validation(item) validated = self.child.run_validation(item)
except ValidationError as exc: except ValidationError as exc:
@ -1068,7 +1100,7 @@ class ModelSerializer(Serializer):
) )
# Determine the fields that should be included on the serializer. # Determine the fields that should be included on the serializer.
fields = OrderedDict() fields = {}
for field_name in field_names: for field_name in field_names:
# If the field is explicitly declared on the class then use that. # If the field is explicitly declared on the class then use that.
@ -1553,16 +1585,16 @@ class ModelSerializer(Serializer):
# which may map onto a model field. Any dotted field name lookups # which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not # cannot map to a field, and must be a traversal, so we're not
# including those. # including those.
field_sources = OrderedDict( field_sources = {
(field.field_name, field.source) for field in self._writable_fields field.field_name: field.source for field in self._writable_fields
if (field.source != '*') and ('.' not in field.source) if (field.source != '*') and ('.' not in field.source)
) }
# Special Case: Add read_only fields with defaults. # Special Case: Add read_only fields with defaults.
field_sources.update(OrderedDict( field_sources.update({
(field.field_name, field.source) for field in self.fields.values() field.field_name: field.source for field in self.fields.values()
if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source) if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
)) })
# Invert so we can find the serializer field names that correspond to # Invert so we can find the serializer field names that correspond to
# the model field names in the unique_together sets. This also allows # the model field names in the unique_together sets. This also allows

File diff suppressed because one or more lines are too long

View File

@ -250,7 +250,7 @@
"csrfToken": "{{ csrf_token }}" "csrfToken": "{{ csrf_token }}"
} }
</script> </script>
<script src="{% static "rest_framework/js/jquery-3.5.1.min.js" %}"></script> <script src="{% static "rest_framework/js/jquery-3.6.4.min.js" %}"></script>
<script src="{% static "rest_framework/js/ajax-form.js" %}"></script> <script src="{% static "rest_framework/js/ajax-form.js" %}"></script>
<script src="{% static "rest_framework/js/csrf.js" %}"></script> <script src="{% static "rest_framework/js/csrf.js" %}"></script>
<script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script> <script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script>

View File

@ -293,7 +293,7 @@
"csrfToken": "{% if request %}{{ csrf_token }}{% endif %}" "csrfToken": "{% if request %}{{ csrf_token }}{% endif %}"
} }
</script> </script>
<script src="{% static "rest_framework/js/jquery-3.5.1.min.js" %}"></script> <script src="{% static "rest_framework/js/jquery-3.6.4.min.js" %}"></script>
<script src="{% static "rest_framework/js/ajax-form.js" %}"></script> <script src="{% static "rest_framework/js/ajax-form.js" %}"></script>
<script src="{% static "rest_framework/js/csrf.js" %}"></script> <script src="{% static "rest_framework/js/csrf.js" %}"></script>
<script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script> <script src="{% static "rest_framework/js/bootstrap.min.js" %}"></script>

View File

@ -30,7 +30,7 @@ being applied unexpectedly?</p>
<p>Your response status code is: <code>{{ response.status_code }}</code></p> <p>Your response status code is: <code>{{ response.status_code }}</code></p>
<h3>401 Unauthorised.</h3> <h3>401 Unauthorized.</h3>
<ul> <ul>
<li>Do you have SessionAuthentication enabled?</li> <li>Do you have SessionAuthentication enabled?</li>
<li>Are you logged in?</li> <li>Are you logged in?</li>
@ -66,6 +66,6 @@ at <code>rest_framework/docs/error.html</code>.</p>
<script src="{% static 'rest_framework/js/jquery-3.5.1.min.js' %}"></script> <script src="{% static 'rest_framework/js/jquery-3.6.4.min.js' %}"></script>
</body> </body>
</html> </html>

View File

@ -38,7 +38,7 @@
{% include "rest_framework/docs/auth/basic.html" %} {% include "rest_framework/docs/auth/basic.html" %}
{% include "rest_framework/docs/auth/session.html" %} {% include "rest_framework/docs/auth/session.html" %}
<script src="{% static 'rest_framework/js/jquery-3.5.1.min.js' %}"></script> <script src="{% static 'rest_framework/js/jquery-3.6.4.min.js' %}"></script>
<script src="{% static 'rest_framework/js/bootstrap.min.js' %}"></script> <script src="{% static 'rest_framework/js/bootstrap.min.js' %}"></script>
<script src="{% static 'rest_framework/docs/js/jquery.json-view.min.js' %}"></script> <script src="{% static 'rest_framework/docs/js/jquery.json-view.min.js' %}"></script>
<script src="{% static 'rest_framework/docs/js/api.js' %}"></script> <script src="{% static 'rest_framework/docs/js/api.js' %}"></script>

View File

@ -11,7 +11,7 @@
{% endif %} {% endif %}
<div class="col-sm-10"> <div class="col-sm-10">
<select multiple {{ field.choices|yesno:",disabled" }} class="form-control" name="{{ field.name }}"> <select multiple class="form-control" name="{{ field.name }}">
{% for select in field.iter_options %} {% for select in field.iter_options %}
{% if select.start_option_group %} {% if select.start_option_group %}
<optgroup label="{{ select.label }}"> <optgroup label="{{ select.label }}">

View File

@ -1,5 +1,4 @@
import re import re
from collections import OrderedDict
from django import template from django import template
from django.template import loader from django.template import loader
@ -49,10 +48,10 @@ def with_location(fields, location):
@register.simple_tag @register.simple_tag
def form_for_link(link): def form_for_link(link):
import coreschema import coreschema
properties = OrderedDict([ properties = {
(field.name, field.schema or coreschema.String()) field.name: field.schema or coreschema.String()
for field in link.fields for field in link.fields
]) }
required = [ required = [
field.name field.name
for field in link.fields for field in link.fields
@ -272,7 +271,7 @@ def schema_links(section, sec_key=None):
links.update(new_links) links.update(new_links)
if sec_key is not None: if sec_key is not None:
new_links = OrderedDict() new_links = {}
for link_key, link in links.items(): for link_key, link in links.items():
new_key = NESTED_FORMAT % (sec_key, link_key) new_key = NESTED_FORMAT % (sec_key, link_key)
new_links.update({new_key: link}) new_links.update({new_key: link})

View File

@ -65,3 +65,14 @@ class JSONEncoder(json.JSONEncoder):
elif hasattr(obj, '__iter__'): elif hasattr(obj, '__iter__'):
return tuple(item for item in obj) return tuple(item for item in obj)
return super().default(obj) return super().default(obj)
class CustomScalar:
"""
CustomScalar that knows how to encode timedelta that renderer
can understand.
"""
@classmethod
def represent_timedelta(cls, dumper, data):
value = str(data.total_seconds())
return dumper.represent_scalar('tag:yaml.org,2002:str', value)

View File

@ -5,7 +5,7 @@ relationships and their associated metadata.
Usage: `get_field_info(model)` returns a `FieldInfo` instance. Usage: `get_field_info(model)` returns a `FieldInfo` instance.
""" """
from collections import OrderedDict, namedtuple from collections import namedtuple
FieldInfo = namedtuple('FieldResult', [ FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance 'pk', # Model field instance
@ -58,7 +58,7 @@ def _get_pk(opts):
def _get_fields(opts): def _get_fields(opts):
fields = OrderedDict() fields = {}
for field in [field for field in opts.fields if field.serialize and not field.remote_field]: for field in [field for field in opts.fields if field.serialize and not field.remote_field]:
fields[field.name] = field fields[field.name] = field
@ -71,9 +71,9 @@ def _get_to_field(field):
def _get_forward_relationships(opts): def _get_forward_relationships(opts):
""" """
Returns an `OrderedDict` of field names to `RelationInfo`. Returns a dict of field names to `RelationInfo`.
""" """
forward_relations = OrderedDict() forward_relations = {}
for field in [field for field in opts.fields if field.serialize and field.remote_field]: for field in [field for field in opts.fields if field.serialize and field.remote_field]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
@ -103,9 +103,9 @@ def _get_forward_relationships(opts):
def _get_reverse_relationships(opts): def _get_reverse_relationships(opts):
""" """
Returns an `OrderedDict` of field names to `RelationInfo`. Returns a dict of field names to `RelationInfo`.
""" """
reverse_relations = OrderedDict() reverse_relations = {}
all_related_objects = [r for r in opts.related_objects if not r.field.many_to_many] all_related_objects = [r for r in opts.related_objects if not r.field.many_to_many]
for relation in all_related_objects: for relation in all_related_objects:
accessor_name = relation.get_accessor_name() accessor_name = relation.get_accessor_name()
@ -139,19 +139,14 @@ def _get_reverse_relationships(opts):
def _merge_fields_and_pk(pk, fields): def _merge_fields_and_pk(pk, fields):
fields_and_pk = OrderedDict() fields_and_pk = {'pk': pk, pk.name: pk}
fields_and_pk['pk'] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields) fields_and_pk.update(fields)
return fields_and_pk return fields_and_pk
def _merge_relationships(forward_relations, reverse_relations): def _merge_relationships(forward_relations, reverse_relations):
return OrderedDict( return {**forward_relations, **reverse_relations}
list(forward_relations.items()) +
list(reverse_relations.items())
)
def is_abstract_model(model): def is_abstract_model(model):

View File

@ -1,6 +1,5 @@
import contextlib import contextlib
import sys import sys
from collections import OrderedDict
from collections.abc import Mapping, MutableMapping from collections.abc import Mapping, MutableMapping
from django.utils.encoding import force_str from django.utils.encoding import force_str
@ -8,7 +7,7 @@ from django.utils.encoding import force_str
from rest_framework.utils import json from rest_framework.utils import json
class ReturnDict(OrderedDict): class ReturnDict(dict):
""" """
Return object from `serializer.data` for the `Serializer` class. Return object from `serializer.data` for the `Serializer` class.
Includes a backlink to the serializer instance for renderers Includes a backlink to the serializer instance for renderers
@ -161,7 +160,7 @@ class BindingDict(MutableMapping):
def __init__(self, serializer): def __init__(self, serializer):
self.serializer = serializer self.serializer = serializer
self.fields = OrderedDict() self.fields = {}
def __setitem__(self, key, field): def __setitem__(self, key, field):
self.fields[key] = field self.fields[key] = field

View File

@ -0,0 +1,25 @@
from datetime import datetime, timezone, tzinfo
def datetime_exists(dt):
"""Check if a datetime exists. Taken from: https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html"""
# There are no non-existent times in UTC, and comparisons between
# aware time zones always compare absolute times; if a datetime is
# not equal to the same datetime represented in UTC, it is imaginary.
return dt.astimezone(timezone.utc) == dt
def datetime_ambiguous(dt: datetime):
"""Check whether a datetime is ambiguous. Taken from: https://pytz-deprecation-shim.readthedocs.io/en/latest/migration.html"""
# If a datetime exists and its UTC offset changes in response to
# changing `fold`, it is ambiguous in the zone specified.
return datetime_exists(dt) and (
dt.replace(fold=not dt.fold).utcoffset() != dt.utcoffset()
)
def valid_datetime(dt):
"""Returns True if the datetime is not ambiguous or imaginary, False otherwise."""
if isinstance(dt.tzinfo, tzinfo) and not datetime_ambiguous(dt):
return True
return False

View File

@ -79,6 +79,15 @@ class UniqueValidator:
smart_repr(self.queryset) smart_repr(self.queryset)
) )
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.lookup == other.lookup
)
class UniqueTogetherValidator: class UniqueTogetherValidator:
""" """
@ -166,6 +175,16 @@ class UniqueTogetherValidator:
smart_repr(self.fields) smart_repr(self.fields)
) )
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.requires_context == other.requires_context
and self.missing_message == other.missing_message
and self.queryset == other.queryset
and self.fields == other.fields
)
class ProhibitSurrogateCharactersValidator: class ProhibitSurrogateCharactersValidator:
message = _('Surrogate characters are not allowed: U+{code_point:X}.') message = _('Surrogate characters are not allowed: U+{code_point:X}.')
@ -177,6 +196,13 @@ class ProhibitSurrogateCharactersValidator:
message = self.message.format(code_point=ord(surrogate_character)) message = self.message.format(code_point=ord(surrogate_character))
raise ValidationError(message, code=self.code) raise ValidationError(message, code=self.code)
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.code == other.code
)
class BaseUniqueForValidator: class BaseUniqueForValidator:
message = None message = None
@ -230,6 +256,17 @@ class BaseUniqueForValidator:
self.field: message self.field: message
}, code='unique') }, code='unique')
def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
return (self.message == other.message
and self.missing_message == other.missing_message
and self.requires_context == other.requires_context
and self.queryset == other.queryset
and self.field == other.field
and self.date_field == other.date_field
)
def __repr__(self): def __repr__(self):
return '<%s(queryset=%s, field=%s, date_field=%s)>' % ( return '<%s(queryset=%s, field=%s, date_field=%s)>' % (
self.__class__.__name__, self.__class__.__name__,

View File

@ -81,8 +81,10 @@ class URLPathVersioning(BaseVersioning):
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None: if request.version is not None:
kwargs = {} if (kwargs is None) else kwargs kwargs = {
kwargs[self.version_param] = request.version self.version_param: request.version,
**(kwargs or {})
}
return super().reverse( return super().reverse(
viewname, args, kwargs, request, format, **extra viewname, args, kwargs, request, format, **extra
@ -117,15 +119,16 @@ class NamespaceVersioning(BaseVersioning):
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
resolver_match = getattr(request, 'resolver_match', None) resolver_match = getattr(request, 'resolver_match', None)
if resolver_match is None or not resolver_match.namespace: if resolver_match is not None and resolver_match.namespace:
return self.default_version # Allow for possibly nested namespaces.
possible_versions = resolver_match.namespace.split(':')
for version in possible_versions:
if self.is_allowed_version(version):
return version
# Allow for possibly nested namespaces. if not self.is_allowed_version(self.default_version):
possible_versions = resolver_match.namespace.split(':') raise exceptions.NotFound(self.invalid_version_message)
for version in possible_versions: return self.default_version
if self.is_allowed_version(version):
return version
raise exceptions.NotFound(self.invalid_version_message)
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra): def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
if request.version is not None: if request.version is not None:

View File

@ -16,7 +16,6 @@ automatically.
router.register(r'users', UserViewSet, 'user') router.register(r'users', UserViewSet, 'user')
urlpatterns = router.urls urlpatterns = router.urls
""" """
from collections import OrderedDict
from functools import update_wrapper from functools import update_wrapper
from inspect import getmembers from inspect import getmembers
@ -183,7 +182,7 @@ class ViewSetMixin:
This method will noop if `detail` was not provided as a view initkwarg. This method will noop if `detail` was not provided as a view initkwarg.
""" """
action_urls = OrderedDict() action_urls = {}
# exit early if `detail` has not been provided # exit early if `detail` has not been provided
if self.detail is None: if self.detail is None:

View File

@ -3,6 +3,8 @@ license_files = LICENSE.md
[tool:pytest] [tool:pytest]
addopts=--tb=short --strict-markers -ra addopts=--tb=short --strict-markers -ra
testspath = tests
filterwarnings = ignore:CoreAPI compatibility is deprecated*:rest_framework.RemovedInDRF317Warning
[flake8] [flake8]
ignore = E501,W503,W504 ignore = E501,W503,W504

View File

@ -37,7 +37,8 @@ an older version of Django REST Framework:
def read(f): def read(f):
return open(f, 'r', encoding='utf-8').read() with open(f, 'r', encoding='utf-8') as file:
return file.read()
def get_version(package): def get_version(package):
@ -82,7 +83,7 @@ setup(
author_email='tom@tomchristie.com', # SEE NOTE BELOW (*) author_email='tom@tomchristie.com', # SEE NOTE BELOW (*)
packages=find_packages(exclude=['tests*']), packages=find_packages(exclude=['tests*']),
include_package_data=True, include_package_data=True,
install_requires=["django>=3.0", "pytz"], install_requires=["django>=3.0", 'backports.zoneinfo;python_version<"3.9"'],
python_requires=">=3.6", python_requires=">=3.6",
zip_safe=False, zip_safe=False,
classifiers=[ classifiers=[

View File

@ -7,16 +7,24 @@ from django.test import TestCase, override_settings
from django.urls import include, path from django.urls import include, path
from rest_framework import ( from rest_framework import (
filters, generics, pagination, permissions, serializers RemovedInDRF317Warning, filters, generics, pagination, permissions,
serializers
) )
from rest_framework.compat import coreapi, coreschema from rest_framework.compat import coreapi, coreschema
from rest_framework.decorators import action, api_view, schema from rest_framework.decorators import action, api_view, schema
from rest_framework.filters import (
BaseFilterBackend, OrderingFilter, SearchFilter
)
from rest_framework.pagination import (
BasePagination, CursorPagination, LimitOffsetPagination,
PageNumberPagination
)
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.schemas import ( from rest_framework.schemas import (
AutoSchema, ManualSchema, SchemaGenerator, get_schema_view AutoSchema, ManualSchema, SchemaGenerator, get_schema_view
) )
from rest_framework.schemas.coreapi import field_to_schema from rest_framework.schemas.coreapi import field_to_schema, is_enabled
from rest_framework.schemas.generators import EndpointEnumerator from rest_framework.schemas.generators import EndpointEnumerator
from rest_framework.schemas.utils import is_list_view from rest_framework.schemas.utils import is_list_view
from rest_framework.test import APIClient, APIRequestFactory from rest_framework.test import APIClient, APIRequestFactory
@ -1433,3 +1441,46 @@ def test_schema_handles_exception():
response.render() response.render()
assert response.status_code == 403 assert response.status_code == 403
assert b"You do not have permission to perform this action." in response.content assert b"You do not have permission to perform this action." in response.content
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
def test_coreapi_deprecation():
with pytest.warns(RemovedInDRF317Warning):
SchemaGenerator()
with pytest.warns(RemovedInDRF317Warning):
AutoSchema()
with pytest.warns(RemovedInDRF317Warning):
ManualSchema({})
with pytest.warns(RemovedInDRF317Warning):
deprecated_filter = OrderingFilter()
deprecated_filter.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
deprecated_filter = BaseFilterBackend()
deprecated_filter.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
deprecated_filter = SearchFilter()
deprecated_filter.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = BasePagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = PageNumberPagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = LimitOffsetPagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
paginator = CursorPagination()
paginator.get_schema_fields({})
with pytest.warns(RemovedInDRF317Warning):
is_enabled()

View File

@ -1162,6 +1162,31 @@ class TestGenerator(TestCase):
assert b'"openapi": "' in ret assert b'"openapi": "' in ret
assert b'"default": "0.0"' in ret assert b'"default": "0.0"' in ret
def test_schema_rendering_to_yaml(self):
patterns = [
path('example/', views.ExampleGenericAPIView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
ret = OpenAPIRenderer().render(schema)
assert b"openapi: " in ret
assert b"default: '0.0'" in ret
def test_schema_rendering_timedelta_to_yaml_with_validator(self):
patterns = [
path('example/', views.ExampleValidatedAPIView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
ret = OpenAPIRenderer().render(schema)
assert b"openapi: " in ret
assert b"duration:\n type: string\n minimum: \'10.0\'\n" in ret
def test_schema_with_no_paths(self): def test_schema_with_no_paths(self):
patterns = [] patterns = []
generator = SchemaGenerator(patterns=patterns) generator = SchemaGenerator(patterns=patterns)

View File

@ -134,6 +134,11 @@ class ExampleValidatedSerializer(serializers.Serializer):
ip4 = serializers.IPAddressField(protocol='ipv4') ip4 = serializers.IPAddressField(protocol='ipv4')
ip6 = serializers.IPAddressField(protocol='ipv6') ip6 = serializers.IPAddressField(protocol='ipv6')
ip = serializers.IPAddressField() ip = serializers.IPAddressField()
duration = serializers.DurationField(
validators=(
MinValueValidator(timedelta(seconds=10)),
)
)
class ExampleValidatedAPIView(generics.GenericAPIView): class ExampleValidatedAPIView(generics.GenericAPIView):

View File

@ -5,10 +5,18 @@ import re
import sys import sys
import uuid import uuid
from decimal import ROUND_DOWN, ROUND_UP, Decimal from decimal import ROUND_DOWN, ROUND_UP, Decimal
from enum import auto
from unittest.mock import patch
import pytest import pytest
import pytz
try:
import pytz
except ImportError:
pytz = None
from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ValidationError as DjangoValidationError
from django.db.models import IntegerChoices, TextChoices
from django.http import QueryDict from django.http import QueryDict
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.utils.timezone import activate, deactivate, override from django.utils.timezone import activate, deactivate, override
@ -21,6 +29,11 @@ from rest_framework.fields import (
) )
from tests.models import UUIDForeignKeyTarget from tests.models import UUIDForeignKeyTarget
if sys.version_info >= (3, 9):
from zoneinfo import ZoneInfo
else:
from backports.zoneinfo import ZoneInfo
utc = datetime.timezone.utc utc = datetime.timezone.utc
# Tests for helper functions. # Tests for helper functions.
@ -651,7 +664,7 @@ class FieldValues:
""" """
Base class for testing valid and invalid input values. Base class for testing valid and invalid input values.
""" """
def test_valid_inputs(self): def test_valid_inputs(self, *args):
""" """
Ensure that valid values return the expected validated data. Ensure that valid values return the expected validated data.
""" """
@ -659,7 +672,7 @@ class FieldValues:
assert self.field.run_validation(input_value) == expected_output, \ assert self.field.run_validation(input_value) == expected_output, \
'input value: {}'.format(repr(input_value)) 'input value: {}'.format(repr(input_value))
def test_invalid_inputs(self): def test_invalid_inputs(self, *args):
""" """
Ensure that invalid values raise the expected validation error. Ensure that invalid values raise the expected validation error.
""" """
@ -669,7 +682,7 @@ class FieldValues:
assert exc_info.value.detail == expected_failure, \ assert exc_info.value.detail == expected_failure, \
'input value: {}'.format(repr(input_value)) 'input value: {}'.format(repr(input_value))
def test_outputs(self): def test_outputs(self, *args):
for output_value, expected_output in get_items(self.outputs): for output_value, expected_output in get_items(self.outputs):
assert self.field.to_representation(output_value) == expected_output, \ assert self.field.to_representation(output_value) == expected_output, \
'output value: {}'.format(repr(output_value)) 'output value: {}'.format(repr(output_value))
@ -682,8 +695,24 @@ class TestBooleanField(FieldValues):
Valid and invalid values for `BooleanField`. Valid and invalid values for `BooleanField`.
""" """
valid_inputs = { valid_inputs = {
'True': True,
'TRUE': True,
'tRuE': True,
't': True,
'T': True,
'true': True, 'true': True,
'on': True,
'ON': True,
'oN': True,
'False': False,
'FALSE': False,
'fALse': False,
'f': False,
'F': False,
'false': False, 'false': False,
'off': False,
'OFF': False,
'oFf': False,
'1': True, '1': True,
'0': False, '0': False,
1: True, 1: True,
@ -696,8 +725,24 @@ class TestBooleanField(FieldValues):
None: ['This field may not be null.'] None: ['This field may not be null.']
} }
outputs = { outputs = {
'True': True,
'TRUE': True,
'tRuE': True,
't': True,
'T': True,
'true': True, 'true': True,
'on': True,
'ON': True,
'oN': True,
'False': False,
'FALSE': False,
'fALse': False,
'f': False,
'F': False,
'false': False, 'false': False,
'off': False,
'OFF': False,
'oFf': False,
'1': True, '1': True,
'0': False, '0': False,
1: True, 1: True,
@ -1208,6 +1253,17 @@ class TestMinMaxDecimalField(FieldValues):
min_value=10, max_value=20 min_value=10, max_value=20
) )
def test_warning_when_not_decimal_types(self, caplog):
import logging
serializers.DecimalField(
max_digits=3, decimal_places=1,
min_value=10, max_value=20
)
assert caplog.record_tuples == [
("rest_framework.fields", logging.WARNING, "max_value in DecimalField should be Decimal type."),
("rest_framework.fields", logging.WARNING, "min_value in DecimalField should be Decimal type.")
]
class TestAllowEmptyStrDecimalFieldWithValidators(FieldValues): class TestAllowEmptyStrDecimalFieldWithValidators(FieldValues):
""" """
@ -1505,12 +1561,12 @@ class TestTZWithDateTimeField(FieldValues):
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
# use class setup method, as class-level attribute will still be evaluated even if test is skipped # use class setup method, as class-level attribute will still be evaluated even if test is skipped
kolkata = pytz.timezone('Asia/Kolkata') kolkata = ZoneInfo('Asia/Kolkata')
cls.valid_inputs = { cls.valid_inputs = {
'2016-12-19T10:00:00': kolkata.localize(datetime.datetime(2016, 12, 19, 10)), '2016-12-19T10:00:00': datetime.datetime(2016, 12, 19, 10, tzinfo=kolkata),
'2016-12-19T10:00:00+05:30': kolkata.localize(datetime.datetime(2016, 12, 19, 10)), '2016-12-19T10:00:00+05:30': datetime.datetime(2016, 12, 19, 10, tzinfo=kolkata),
datetime.datetime(2016, 12, 19, 10): kolkata.localize(datetime.datetime(2016, 12, 19, 10)), datetime.datetime(2016, 12, 19, 10): datetime.datetime(2016, 12, 19, 10, tzinfo=kolkata),
} }
cls.invalid_inputs = {} cls.invalid_inputs = {}
cls.outputs = { cls.outputs = {
@ -1529,7 +1585,7 @@ class TestDefaultTZDateTimeField(TestCase):
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
cls.field = serializers.DateTimeField() cls.field = serializers.DateTimeField()
cls.kolkata = pytz.timezone('Asia/Kolkata') cls.kolkata = ZoneInfo('Asia/Kolkata')
def assertUTC(self, tzinfo): def assertUTC(self, tzinfo):
""" """
@ -1551,18 +1607,17 @@ class TestDefaultTZDateTimeField(TestCase):
self.assertUTC(self.field.default_timezone()) self.assertUTC(self.field.default_timezone())
@pytest.mark.skipif(pytz is None, reason='pytz not installed')
@override_settings(TIME_ZONE='UTC', USE_TZ=True) @override_settings(TIME_ZONE='UTC', USE_TZ=True)
class TestCustomTimezoneForDateTimeField(TestCase): class TestCustomTimezoneForDateTimeField(TestCase):
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
cls.kolkata = pytz.timezone('Asia/Kolkata') cls.kolkata = ZoneInfo('Asia/Kolkata')
cls.date_format = '%d/%m/%Y %H:%M' cls.date_format = '%d/%m/%Y %H:%M'
def test_should_render_date_time_in_default_timezone(self): def test_should_render_date_time_in_default_timezone(self):
field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format) field = serializers.DateTimeField(default_timezone=self.kolkata, format=self.date_format)
dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=pytz.utc) dt = datetime.datetime(2018, 2, 8, 14, 15, 16, tzinfo=ZoneInfo("UTC"))
with override(self.kolkata): with override(self.kolkata):
rendered_date = field.to_representation(dt) rendered_date = field.to_representation(dt)
@ -1572,6 +1627,33 @@ class TestCustomTimezoneForDateTimeField(TestCase):
assert rendered_date == rendered_date_in_timezone assert rendered_date == rendered_date_in_timezone
@pytest.mark.skipif(pytz is None, reason="As Django 4.0 has deprecated pytz, this test should eventually be able to get removed.")
class TestPytzNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
"""
Invalid values for `DateTimeField` with datetime in DST shift (non-existing or ambiguous) and timezone with DST.
Timezone America/New_York has DST shift from 2017-03-12T02:00:00 to 2017-03-12T03:00:00 and
from 2017-11-05T02:00:00 to 2017-11-05T01:00:00 in 2017.
"""
valid_inputs = {}
invalid_inputs = {
'2017-03-12T02:30:00': ['Invalid datetime for the timezone "America/New_York".'],
'2017-11-05T01:30:00': ['Invalid datetime for the timezone "America/New_York".']
}
outputs = {}
if pytz:
class MockTimezone(pytz.BaseTzInfo):
@staticmethod
def localize(value, is_dst):
raise pytz.InvalidTimeError()
def __str__(self):
return 'America/New_York'
field = serializers.DateTimeField(default_timezone=MockTimezone())
@patch('rest_framework.utils.timezone.datetime_ambiguous', return_value=True)
class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues): class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
""" """
Invalid values for `DateTimeField` with datetime in DST shift (non-existing or ambiguous) and timezone with DST. Invalid values for `DateTimeField` with datetime in DST shift (non-existing or ambiguous) and timezone with DST.
@ -1585,15 +1667,11 @@ class TestNaiveDayLightSavingTimeTimeZoneDateTimeField(FieldValues):
} }
outputs = {} outputs = {}
class MockTimezone(pytz.BaseTzInfo): class MockZoneInfoTimezone(datetime.tzinfo):
@staticmethod
def localize(value, is_dst):
raise pytz.InvalidTimeError()
def __str__(self): def __str__(self):
return 'America/New_York' return 'America/New_York'
field = serializers.DateTimeField(default_timezone=MockTimezone()) field = serializers.DateTimeField(default_timezone=MockZoneInfoTimezone())
class TestTimeField(FieldValues): class TestTimeField(FieldValues):
@ -1797,6 +1875,54 @@ class TestChoiceField(FieldValues):
field.run_validation(2) field.run_validation(2)
assert exc_info.value.detail == ['"2" is not a valid choice.'] assert exc_info.value.detail == ['"2" is not a valid choice.']
def test_integer_choices(self):
class ChoiceCase(IntegerChoices):
first = auto()
second = auto()
# Enum validate
choices = [
(ChoiceCase.first, "1"),
(ChoiceCase.second, "2")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1
choices = [
(ChoiceCase.first.value, "1"),
(ChoiceCase.second.value, "2")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(1) == 1
assert field.run_validation(ChoiceCase.first) == 1
assert field.run_validation("1") == 1
def test_text_choices(self):
class ChoiceCase(TextChoices):
first = auto()
second = auto()
# Enum validate
choices = [
(ChoiceCase.first, "first"),
(ChoiceCase.second, "second")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(ChoiceCase.first) == "first"
assert field.run_validation("first") == "first"
choices = [
(ChoiceCase.first.value, "first"),
(ChoiceCase.second.value, "second")
]
field = serializers.ChoiceField(choices=choices)
assert field.run_validation(ChoiceCase.first) == "first"
assert field.run_validation("first") == "first"
class TestChoiceFieldWithType(FieldValues): class TestChoiceFieldWithType(FieldValues):
""" """

View File

@ -324,6 +324,13 @@ class TestSimpleMetadataFieldInfo(TestCase):
) )
assert 'choices' not in field_info assert 'choices' not in field_info
def test_decimal_field_info_type(self):
options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.DecimalField(max_digits=18, decimal_places=4))
assert field_info['type'] == 'decimal'
assert field_info['max_digits'] == 18
assert field_info['decimal_places'] == 4
class TestModelSerializerMetadata(TestCase): class TestModelSerializerMetadata(TestCase):
def test_read_only_primary_key_related_field(self): def test_read_only_primary_key_related_field(self):

View File

@ -10,7 +10,6 @@ import decimal
import json # noqa import json # noqa
import sys import sys
import tempfile import tempfile
from collections import OrderedDict
import django import django
import pytest import pytest
@ -762,7 +761,7 @@ class TestRelationalFieldDisplayValue(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
expected = OrderedDict([(1, 'Red Color'), (2, 'Yellow Color'), (3, 'Green Color')]) expected = {1: 'Red Color', 2: 'Yellow Color', 3: 'Green Color'}
self.assertEqual(serializer.fields['color'].choices, expected) self.assertEqual(serializer.fields['color'].choices, expected)
def test_custom_display_value(self): def test_custom_display_value(self):
@ -778,7 +777,7 @@ class TestRelationalFieldDisplayValue(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
expected = OrderedDict([(1, 'My Red Color'), (2, 'My Yellow Color'), (3, 'My Green Color')]) expected = {1: 'My Red Color', 2: 'My Yellow Color', 3: 'My Green Color'}
self.assertEqual(serializer.fields['color'].choices, expected) self.assertEqual(serializer.fields['color'].choices, expected)

View File

@ -632,6 +632,24 @@ class CursorPaginationTestsMixin:
ordering = self.pagination.get_ordering(request, [], MockView()) ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',) assert ordering == ('created',)
def test_use_with_ordering_filter_without_ordering_default_value(self):
class MockView:
filter_backends = (filters.OrderingFilter,)
ordering_fields = ['username', 'created']
request = Request(factory.get('/'))
ordering = self.pagination.get_ordering(request, [], MockView())
# it gets the value of `ordering` provided by CursorPagination
assert ordering == ('created',)
request = Request(factory.get('/', {'ordering': 'username'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('username',)
request = Request(factory.get('/', {'ordering': 'invalid'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',)
def test_cursor_pagination(self): def test_cursor_pagination(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/') (previous, current, next, previous_url, next_url) = self.get_pages('/')
@ -951,17 +969,24 @@ class TestCursorPagination(CursorPaginationTestsMixin):
def __init__(self, items): def __init__(self, items):
self.items = items self.items = items
def filter(self, created__gt=None, created__lt=None): def filter(self, q):
q_args = dict(q.deconstruct()[1])
if not q_args:
# django 3.0.x artifact
q_args = dict(q.deconstruct()[2])
created__gt = q_args.get('created__gt')
created__lt = q_args.get('created__lt')
if created__gt is not None: if created__gt is not None:
return MockQuerySet([ return MockQuerySet([
item for item in self.items item for item in self.items
if item.created > int(created__gt) if item.created is None or item.created > int(created__gt)
]) ])
assert created__lt is not None assert created__lt is not None
return MockQuerySet([ return MockQuerySet([
item for item in self.items item for item in self.items
if item.created < int(created__lt) if item.created is None or item.created < int(created__lt)
]) ])
def order_by(self, *ordering): def order_by(self, *ordering):
@ -1080,6 +1105,127 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
return (previous, current, next, previous_url, next_url) return (previous, current, next, previous_url, next_url)
class NullableCursorPaginationModel(models.Model):
created = models.IntegerField(null=True)
class TestCursorPaginationWithNulls(TestCase):
"""
Unit tests for `pagination.CursorPagination` with ordering on a nullable field.
"""
def setUp(self):
class ExamplePagination(pagination.CursorPagination):
page_size = 1
ordering = 'created'
self.pagination = ExamplePagination()
data = [
None, None, 3, 4
]
for idx in data:
NullableCursorPaginationModel.objects.create(created=idx)
self.queryset = NullableCursorPaginationModel.objects.all()
get_pages = TestCursorPagination.get_pages
def test_ascending(self):
"""Test paginating one row at a time, current should go 1, 2, 3, 4, 3, 2, 1."""
(previous, current, next, previous_url, next_url) = self.get_pages('/')
assert previous is None
assert current == [None]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None]
assert current == [None]
assert next == [3]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [3] # [None] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L789
assert current == [3]
assert next == [4]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [3]
assert current == [4]
assert next is None
assert next_url is None
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [None]
assert current == [3]
assert next == [4]
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [None]
assert current == [None]
assert next == [None] # [3] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous is None
assert current == [None]
assert next == [None]
def test_descending(self):
"""Test paginating one row at a time, current should go 4, 3, 2, 1, 2, 3, 4."""
self.pagination.ordering = ('-created',)
(previous, current, next, previous_url, next_url) = self.get_pages('/')
assert previous is None
assert current == [4]
assert next == [3]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None] # [4] paging artifact
assert current == [3]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None] # [3] paging artifact
assert current == [None]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(next_url)
assert previous == [None]
assert current == [None]
assert next is None
assert next_url is None
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [3]
assert current == [None]
assert next == [None]
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous == [None]
assert current == [3]
assert next == [3] # [4] paging artifact documented at https://github.com/ddelange/django-rest-framework/blob/3.14.0/rest_framework/pagination.py#L731
# skip back artifact
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
(previous, current, next, previous_url, next_url) = self.get_pages(previous_url)
assert previous is None
assert current == [4]
assert next == [3]
def test_get_displayed_page_numbers(): def test_get_displayed_page_numbers():
""" """
Test our contextual page display function. Test our contextual page display function.

View File

@ -342,6 +342,142 @@ class TestSlugRelatedField(APISimpleTestCase):
field.to_internal_value(self.instance.name) field.to_internal_value(self.instance.name)
class TestNestedSlugRelatedField(APISimpleTestCase):
def setUp(self):
self.queryset = MockQueryset([
MockObject(
pk=1, name='foo', nested=MockObject(
pk=2, name='bar', nested=MockObject(
pk=7, name="foobar"
)
)
),
MockObject(
pk=3, name='hello', nested=MockObject(
pk=4, name='world', nested=MockObject(
pk=8, name="helloworld"
)
)
),
MockObject(
pk=5, name='harry', nested=MockObject(
pk=6, name='potter', nested=MockObject(
pk=9, name="harrypotter"
)
)
)
])
self.instance = self.queryset.items[2]
self.field = serializers.SlugRelatedField(
slug_field='name', queryset=self.queryset
)
self.nested_field = serializers.SlugRelatedField(
slug_field='nested__name', queryset=self.queryset
)
self.nested_nested_field = serializers.SlugRelatedField(
slug_field='nested__nested__name', queryset=self.queryset
)
# testing nested inside nested relations
def test_slug_related_nested_nested_lookup_exists(self):
instance = self.nested_nested_field.to_internal_value(
self.instance.nested.nested.name
)
assert instance is self.instance
def test_slug_related_nested_nested_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_nested_field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0]
assert msg == \
'Object with nested__nested__name=doesnotexist does not exist.'
def test_slug_related_nested_nested_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_nested_field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Invalid value.'
def test_nested_nested_representation(self):
representation =\
self.nested_nested_field.to_representation(self.instance)
assert representation == self.instance.nested.nested.name
def test_nested_nested_overriding_get_queryset(self):
qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
return qs
field = NoQuerySetSlugRelatedField(slug_field='nested__nested__name')
field.to_internal_value(self.instance.nested.nested.name)
# testing nested relations
def test_slug_related_nested_lookup_exists(self):
instance = \
self.nested_field.to_internal_value(self.instance.nested.name)
assert instance is self.instance
def test_slug_related_nested_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0]
assert msg == 'Object with nested__name=doesnotexist does not exist.'
def test_slug_related_nested_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.nested_field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Invalid value.'
def test_nested_representation(self):
representation = self.nested_field.to_representation(self.instance)
assert representation == self.instance.nested.name
def test_nested_overriding_get_queryset(self):
qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
return qs
field = NoQuerySetSlugRelatedField(slug_field='nested__name')
field.to_internal_value(self.instance.nested.name)
# testing non-nested relations
def test_slug_related_lookup_exists(self):
instance = self.field.to_internal_value(self.instance.name)
assert instance is self.instance
def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0]
assert msg == 'Object with name=doesnotexist does not exist.'
def test_slug_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Invalid value.'
def test_representation(self):
representation = self.field.to_representation(self.instance)
assert representation == self.instance.name
def test_overriding_get_queryset(self):
qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField):
def get_queryset(self):
return qs
field = NoQuerySetSlugRelatedField(slug_field='name')
field.to_internal_value(self.instance.name)
class TestManyRelatedField(APISimpleTestCase): class TestManyRelatedField(APISimpleTestCase):
def setUp(self): def setUp(self):
self.instance = MockObject(pk=1, name='foo') self.instance = MockObject(pk=1, name='foo')

View File

@ -1,5 +1,4 @@
import re import re
from collections import OrderedDict
from collections.abc import MutableMapping from collections.abc import MutableMapping
import pytest import pytest
@ -457,12 +456,12 @@ class CacheRenderTest(TestCase):
class TestJSONIndentationStyles: class TestJSONIndentationStyles:
def test_indented(self): def test_indented(self):
renderer = JSONRenderer() renderer = JSONRenderer()
data = OrderedDict([('a', 1), ('b', 2)]) data = {"a": 1, "b": 2}
assert renderer.render(data) == b'{"a":1,"b":2}' assert renderer.render(data) == b'{"a":1,"b":2}'
def test_compact(self): def test_compact(self):
renderer = JSONRenderer() renderer = JSONRenderer()
data = OrderedDict([('a', 1), ('b', 2)]) data = {"a": 1, "b": 2}
context = {'indent': 4} context = {'indent': 4}
assert ( assert (
renderer.render(data, renderer_context=context) == renderer.render(data, renderer_context=context) ==
@ -472,7 +471,7 @@ class TestJSONIndentationStyles:
def test_long_form(self): def test_long_form(self):
renderer = JSONRenderer() renderer = JSONRenderer()
renderer.compact = False renderer.compact = False
data = OrderedDict([('a', 1), ('b', 2)]) data = {"a": 1, "b": 2}
assert renderer.render(data) == b'{"a": 1, "b": 2}' assert renderer.render(data) == b'{"a": 1, "b": 2}'
@ -634,6 +633,9 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
class AuthExampleViewSet(ExampleViewSet): class AuthExampleViewSet(ExampleViewSet):
permission_classes = [permissions.IsAuthenticated] permission_classes = [permissions.IsAuthenticated]
class SimpleSerializer(serializers.Serializer):
name = serializers.CharField()
router = SimpleRouter() router = SimpleRouter()
router.register('examples', ExampleViewSet, basename='example') router.register('examples', ExampleViewSet, basename='example')
router.register('auth-examples', AuthExampleViewSet, basename='auth-example') router.register('auth-examples', AuthExampleViewSet, basename='auth-example')
@ -641,6 +643,62 @@ class BrowsableAPIRendererTests(URLPatternsTestCase):
def setUp(self): def setUp(self):
self.renderer = BrowsableAPIRenderer() self.renderer = BrowsableAPIRenderer()
self.renderer.accepted_media_type = ''
self.renderer.renderer_context = {}
def test_render_form_for_serializer(self):
with self.subTest('Serializer'):
serializer = BrowsableAPIRendererTests.SimpleSerializer(data={'name': 'Name'})
form = self.renderer.render_form_for_serializer(serializer)
assert isinstance(form, str), 'Must return form for serializer'
with self.subTest('ListSerializer'):
list_serializer = BrowsableAPIRendererTests.SimpleSerializer(data=[{'name': 'Name'}], many=True)
form = self.renderer.render_form_for_serializer(list_serializer)
assert form is None, 'Must not return form for list serializer'
def test_get_raw_data_form(self):
with self.subTest('Serializer'):
class DummyGenericViewsetLike(APIView):
def get_serializer(self, **kwargs):
return BrowsableAPIRendererTests.SimpleSerializer(**kwargs)
def get(self, request):
response = Response()
response.view = self
return response
post = get
view = DummyGenericViewsetLike.as_view()
_request = APIRequestFactory().get('/')
request = Request(_request)
response = view(_request)
view = response.view
raw_data_form = self.renderer.get_raw_data_form({'name': 'Name'}, view, 'POST', request)
assert raw_data_form['_content'].initial == '{\n "name": ""\n}'
with self.subTest('ListSerializer'):
class DummyGenericViewsetLike(APIView):
def get_serializer(self, **kwargs):
return BrowsableAPIRendererTests.SimpleSerializer(many=True, **kwargs) # returns ListSerializer
def get(self, request):
response = Response()
response.view = self
return response
post = get
view = DummyGenericViewsetLike.as_view()
_request = APIRequestFactory().get('/')
request = Request(_request)
response = view(_request)
view = response.view
raw_data_form = self.renderer.get_raw_data_form([{'name': 'Name'}], view, 'POST', request)
assert raw_data_form['_content'].initial == '[\n {\n "name": ""\n }\n]'
def test_get_description_returns_empty_string_for_401_and_403_statuses(self): def test_get_description_returns_empty_string_for_401_and_403_statuses(self):
assert self.renderer.get_description({}, status_code=401) == '' assert self.renderer.get_description({}, status_code=401) == ''

View File

@ -2,6 +2,7 @@ import inspect
import pickle import pickle
import re import re
import sys import sys
import unittest
from collections import ChainMap from collections import ChainMap
from collections.abc import Mapping from collections.abc import Mapping
@ -764,3 +765,84 @@ class Test8301Regression:
assert (s.data | {}).__class__ == s.data.__class__ assert (s.data | {}).__class__ == s.data.__class__
assert ({} | s.data).__class__ == s.data.__class__ assert ({} | s.data).__class__ == s.data.__class__
class TestSetValueMethod:
# Serializer.set_value() modifies the first parameter in-place.
s = serializers.Serializer()
def test_no_keys(self):
ret = {'a': 1}
self.s.set_value(ret, [], {'b': 2})
assert ret == {'a': 1, 'b': 2}
def test_one_key(self):
ret = {'a': 1}
self.s.set_value(ret, ['x'], 2)
assert ret == {'a': 1, 'x': 2}
def test_nested_key(self):
ret = {'a': 1}
self.s.set_value(ret, ['x', 'y'], 2)
assert ret == {'a': 1, 'x': {'y': 2}}
class MyClass(models.Model):
name = models.CharField(max_length=100)
value = models.CharField(max_length=100, blank=True)
app_label = "test"
@property
def is_valid(self):
return self.name == 'valid'
class MyClassSerializer(serializers.ModelSerializer):
class Meta:
model = MyClass
fields = ('id', 'name', 'value')
def validate_value(self, value):
if value and not self.instance.is_valid:
raise serializers.ValidationError(
'Status cannot be set for invalid instance')
return value
class TestMultipleObjectsValidation(unittest.TestCase):
def setUp(self):
self.objs = [
MyClass(name='valid'),
MyClass(name='invalid'),
MyClass(name='other'),
]
def test_multiple_objects_are_validated_separately(self):
serializer = MyClassSerializer(
data=[{'value': 'set', 'id': instance.id} for instance in
self.objs],
instance=self.objs,
many=True,
partial=True,
)
assert not serializer.is_valid()
assert serializer.errors == [
{},
{'value': ['Status cannot be set for invalid instance']},
{'value': ['Status cannot be set for invalid instance']}
]
def test_exception_raised_when_data_and_instance_length_different(self):
with self.assertRaises(AssertionError):
MyClassSerializer(
data=[{'value': 'set', 'id': instance.id} for instance in
self.objs],
instance=self.objs[:-1],
many=True,
partial=True,
)

View File

@ -1,4 +1,5 @@
import datetime import datetime
from unittest.mock import MagicMock
import pytest import pytest
from django.db import DataError, models from django.db import DataError, models
@ -787,3 +788,13 @@ class ValidatorsTests(TestCase):
validator.filter_queryset( validator.filter_queryset(
attrs=None, queryset=None, field_name='', date_field_name='' attrs=None, queryset=None, field_name='', date_field_name=''
) )
def test_equality_operator(self):
mock_queryset = MagicMock()
validator = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
validator2 = BaseUniqueForValidator(queryset=mock_queryset, field='foo',
date_field='bar')
assert validator == validator2
validator2.date_field = "bar2"
assert validator != validator2

View File

@ -152,6 +152,8 @@ class TestURLReversing(URLPatternsTestCase, APITestCase):
path('v1/', include((included, 'v1'), namespace='v1')), path('v1/', include((included, 'v1'), namespace='v1')),
path('another/', dummy_view, name='another'), path('another/', dummy_view, name='another'),
re_path(r'^(?P<version>[v1|v2]+)/another/$', dummy_view, name='another'), re_path(r'^(?P<version>[v1|v2]+)/another/$', dummy_view, name='another'),
re_path(r'^(?P<foo>.+)/unversioned/$', dummy_view, name='unversioned'),
] ]
def test_reverse_unversioned(self): def test_reverse_unversioned(self):
@ -198,6 +200,14 @@ class TestURLReversing(URLPatternsTestCase, APITestCase):
response = view(request) response = view(request)
assert response.data == {'url': 'http://testserver/another/'} assert response.data == {'url': 'http://testserver/another/'}
# Test fallback when kwargs is not None
request = factory.get('/v1/endpoint/')
request.versioning_scheme = scheme()
request.version = 'v1'
reversed_url = reverse('unversioned', request=request, kwargs={'foo': 'bar'})
assert reversed_url == 'http://testserver/bar/unversioned/'
def test_reverse_namespace_versioning(self): def test_reverse_namespace_versioning(self):
class FakeResolverMatch(ResolverMatch): class FakeResolverMatch(ResolverMatch):
namespace = 'v1' namespace = 'v1'
@ -262,7 +272,7 @@ class TestInvalidVersion:
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
class TestAllowedAndDefaultVersion: class TestAcceptHeaderAllowedAndDefaultVersion:
def test_missing_without_default(self): def test_missing_without_default(self):
scheme = versioning.AcceptHeaderVersioning scheme = versioning.AcceptHeaderVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme) view = AllowedVersionsView.as_view(versioning_class=scheme)
@ -308,6 +318,97 @@ class TestAllowedAndDefaultVersion:
assert response.data == {'version': 'v2'} assert response.data == {'version': 'v2'}
class TestNamespaceAllowedAndDefaultVersion:
def test_no_namespace_without_default(self):
class FakeResolverMatch:
namespace = None
scheme = versioning.NamespaceVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_no_namespace_with_default(self):
class FakeResolverMatch:
namespace = None
scheme = versioning.NamespaceVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
def test_no_match_without_default(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_no_match_with_default(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
def test_with_default(self):
class FakeResolverMatch:
namespace = 'v1'
scheme = versioning.NamespaceVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v1'}
def test_no_match_without_default_but_none_allowed(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedWithNoneVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': None}
def test_no_match_with_default_and_none_allowed(self):
class FakeResolverMatch:
namespace = 'no_match'
scheme = versioning.NamespaceVersioning
view = AllowedWithNoneAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/')
request.resolver_match = FakeResolverMatch
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
class TestHyperlinkedRelatedField(URLPatternsTestCase, APITestCase): class TestHyperlinkedRelatedField(URLPatternsTestCase, APITestCase):
included = [ included = [
path('namespaced/<int:pk>/', dummy_pk_view, name='namespaced'), path('namespaced/<int:pk>/', dummy_pk_view, name='namespaced'),

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
from functools import wraps from functools import wraps
import pytest import pytest
@ -261,11 +260,11 @@ class GetExtraActionUrlMapTests(TestCase):
response = self.client.get('/api/actions/') response = self.client.get('/api/actions/')
view = response.view view = response.view
expected = OrderedDict([ expected = {
('Custom list action', 'http://testserver/api/actions/custom_list_action/'), 'Custom list action': 'http://testserver/api/actions/custom_list_action/',
('List action', 'http://testserver/api/actions/list_action/'), 'List action': 'http://testserver/api/actions/list_action/',
('Wrapped list action', 'http://testserver/api/actions/wrapped_list_action/'), 'Wrapped list action': 'http://testserver/api/actions/wrapped_list_action/',
]) }
self.assertEqual(view.get_extra_action_url_map(), expected) self.assertEqual(view.get_extra_action_url_map(), expected)
@ -273,28 +272,28 @@ class GetExtraActionUrlMapTests(TestCase):
response = self.client.get('/api/actions/1/') response = self.client.get('/api/actions/1/')
view = response.view view = response.view
expected = OrderedDict([ expected = {
('Custom detail action', 'http://testserver/api/actions/1/custom_detail_action/'), 'Custom detail action': 'http://testserver/api/actions/1/custom_detail_action/',
('Detail action', 'http://testserver/api/actions/1/detail_action/'), 'Detail action': 'http://testserver/api/actions/1/detail_action/',
('Wrapped detail action', 'http://testserver/api/actions/1/wrapped_detail_action/'), 'Wrapped detail action': 'http://testserver/api/actions/1/wrapped_detail_action/',
# "Unresolvable detail action" excluded, since it's not resolvable # "Unresolvable detail action" excluded, since it's not resolvable
]) }
self.assertEqual(view.get_extra_action_url_map(), expected) self.assertEqual(view.get_extra_action_url_map(), expected)
def test_uninitialized_view(self): def test_uninitialized_view(self):
self.assertEqual(ActionViewSet().get_extra_action_url_map(), OrderedDict()) self.assertEqual(ActionViewSet().get_extra_action_url_map(), {})
def test_action_names(self): def test_action_names(self):
# Action 'name' and 'suffix' kwargs should be respected # Action 'name' and 'suffix' kwargs should be respected
response = self.client.get('/api/names/1/') response = self.client.get('/api/names/1/')
view = response.view view = response.view
expected = OrderedDict([ expected = {
('Custom Name', 'http://testserver/api/names/1/named_action/'), 'Custom Name': 'http://testserver/api/names/1/named_action/',
('Action Names Custom Suffix', 'http://testserver/api/names/1/suffixed_action/'), 'Action Names Custom Suffix': 'http://testserver/api/names/1/suffixed_action/',
('Unnamed action', 'http://testserver/api/names/1/unnamed_action/'), 'Unnamed action': 'http://testserver/api/names/1/unnamed_action/',
]) }
self.assertEqual(view.get_extra_action_url_map(), expected) self.assertEqual(view.get_extra_action_url_map(), expected)

View File

@ -1,3 +1,5 @@
from operator import attrgetter
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from django.urls import NoReverseMatch from django.urls import NoReverseMatch
@ -26,7 +28,7 @@ class MockQueryset:
def get(self, **lookup): def get(self, **lookup):
for item in self.items: for item in self.items:
if all([ if all([
getattr(item, key, None) == value attrgetter(key.replace('__', '.'))(item) == value
for key, value in lookup.items() for key, value in lookup.items()
]): ]):
return item return item
@ -39,6 +41,7 @@ class BadType:
will raise a `TypeError`, as occurs in Django when making will raise a `TypeError`, as occurs in Django when making
queryset lookups with an incorrect type for the lookup value. queryset lookups with an incorrect type for the lookup value.
""" """
def __eq__(self): def __eq__(self):
raise TypeError() raise TypeError()

View File

@ -21,7 +21,7 @@ deps =
django32: Django>=3.2,<4.0 django32: Django>=3.2,<4.0
django40: Django>=4.0,<4.1 django40: Django>=4.0,<4.1
django41: Django>=4.1,<4.2 django41: Django>=4.1,<4.2
django42: Django>=4.2b1,<5.0 django42: Django>=4.2,<5.0
djangomain: https://github.com/django/django/archive/main.tar.gz djangomain: https://github.com/django/django/archive/main.tar.gz
-rrequirements/requirements-testing.txt -rrequirements/requirements-testing.txt
-rrequirements/requirements-optionals.txt -rrequirements/requirements-optionals.txt