diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a6e554b9..83fe0b714 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - id: check-symlinks - id: check-toml - repo: https://github.com/pycqa/isort - rev: 5.8.0 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/PyCQA/flake8 diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md index 0d14cac46..2b717b6e6 100644 --- a/docs/api-guide/serializers.md +++ b/docs/api-guide/serializers.md @@ -226,7 +226,7 @@ Individual fields on a serializer can include validators, by declaring them on t raise serializers.ValidationError('Not a multiple of ten') class GameRecord(serializers.Serializer): - score = IntegerField(validators=[multiple_of_ten]) + score = serializers.IntegerField(validators=[multiple_of_ten]) ... Serializer classes can also include reusable validators that are applied to the complete set of field data. These validators are included by declaring them on an inner `Meta` class, like so: diff --git a/docs/tutorial/6-viewsets-and-routers.md b/docs/tutorial/6-viewsets-and-routers.md index 74789e337..f9b6c5e9a 100644 --- a/docs/tutorial/6-viewsets-and-routers.md +++ b/docs/tutorial/6-viewsets-and-routers.md @@ -27,6 +27,7 @@ Here we've used the `ReadOnlyModelViewSet` class to automatically provide the de Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class. from rest_framework import permissions + from rest_framework import renderers from rest_framework.decorators import action from rest_framework.response import Response diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 09f111102..bc20fcaa3 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -144,17 +144,30 @@ class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST default_detail = _('Invalid input.') default_code = 'invalid' + default_params = {} - def __init__(self, detail=None, code=None): + def __init__(self, detail=None, code=None, params=None): if detail is None: detail = self.default_detail if code is None: code = self.default_code + if params is None: + params = self.default_params # For validation failures, we may collect many errors together, # so the details should always be coerced to a list if not already. - if isinstance(detail, tuple): - detail = list(detail) + if isinstance(detail, str): + detail = [detail % params] + elif isinstance(detail, ValidationError): + detail = detail.detail + elif isinstance(detail, (list, tuple)): + final_detail = [] + for detail_item in detail: + if isinstance(detail_item, ValidationError): + final_detail += detail_item.detail + else: + final_detail += [detail_item % params if isinstance(detail_item, str) else detail_item] + detail = final_detail elif not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1c6425596..613bd325a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -356,6 +356,10 @@ class Field: messages.update(error_messages or {}) self.error_messages = messages + # Allow generic typing checking for fields. + def __class_getitem__(cls, *args, **kwargs): + return cls + def bind(self, field_name, parent): """ Initializes the field name and parent for the field instance. diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55cfafda4..167303321 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -45,6 +45,10 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + # Allow generic typing checking for generic views. + def __class_getitem__(cls, *args, **kwargs): + return cls + def get_queryset(self): """ Get the list of items for this view. diff --git a/rest_framework/request.py b/rest_framework/request.py index 194be5f6d..93109226d 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -186,6 +186,10 @@ class Request: self.method, self.get_full_path()) + # Allow generic typing checking for requests. + def __class_getitem__(cls, *args, **kwargs): + return cls + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/response.py b/rest_framework/response.py index 495423734..6e756544c 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -46,6 +46,10 @@ class Response(SimpleTemplateResponse): for name, value in headers.items(): self[name] = value + # Allow generic typing checking for responses. + def __class_getitem__(cls, *args, **kwargs): + return cls + @property def rendered_content(self): renderer = getattr(self, 'accepted_renderer', None) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index eae6a0b2e..e27f8a47c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1398,6 +1398,23 @@ class ModelSerializer(Serializer): return extra_kwargs + def get_unique_together_constraints(self, model): + """ + Returns iterator of (fields, queryset), each entry describes an unique together + constraint on `fields` in `queryset`. + """ + for parent_class in [model] + list(model._meta.parents): + for unique_together in parent_class._meta.unique_together: + yield unique_together, model._default_manager + for constraint in parent_class._meta.constraints: + if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: + yield ( + constraint.fields, + model._default_manager + if constraint.condition is None + else model._default_manager.filter(constraint.condition) + ) + def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): """ Return any additional field options that need to be included as a @@ -1426,12 +1443,11 @@ class ModelSerializer(Serializer): unique_constraint_names -= {None} - # Include each of the `unique_together` field names, + # Include each of the `unique_together` and `UniqueConstraint` field names, # so long as all the field names are included on the serializer. - for parent_class in [model] + list(model._meta.parents): - for unique_together_list in parent_class._meta.unique_together: - if set(field_names).issuperset(unique_together_list): - unique_constraint_names |= set(unique_together_list) + for unique_together_list, queryset in self.get_unique_together_constraints(model): + if set(field_names).issuperset(unique_together_list): + unique_constraint_names |= set(unique_together_list) # Now we have all the field names that have uniqueness constraints # applied, we can add the extra 'required=...' or 'default=...' @@ -1526,11 +1542,6 @@ class ModelSerializer(Serializer): """ Determine a default set of validators for any unique_together constraints. """ - model_class_inheritance_tree = ( - [self.Meta.model] + - list(self.Meta.model._meta.parents) - ) - # The field names we're passing though here only include fields # which may map onto a model field. Any dotted field name lookups # cannot map to a field, and must be a traversal, so we're not @@ -1556,34 +1567,33 @@ class ModelSerializer(Serializer): # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] - for parent_class in model_class_inheritance_tree: - for unique_together in parent_class._meta.unique_together: - # Skip if serializer does not map to all unique together sources - if not set(source_map).issuperset(unique_together): - continue + for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model): + # Skip if serializer does not map to all unique together sources + if not set(source_map).issuperset(unique_together): + continue - for source in unique_together: - assert len(source_map[source]) == 1, ( - "Unable to create `UniqueTogetherValidator` for " - "`{model}.{field}` as `{serializer}` has multiple " - "fields ({fields}) that map to this model field. " - "Either remove the extra fields, or override " - "`Meta.validators` with a `UniqueTogetherValidator` " - "using the desired field names." - .format( - model=self.Meta.model.__name__, - serializer=self.__class__.__name__, - field=source, - fields=', '.join(source_map[source]), - ) + for source in unique_together: + assert len(source_map[source]) == 1, ( + "Unable to create `UniqueTogetherValidator` for " + "`{model}.{field}` as `{serializer}` has multiple " + "fields ({fields}) that map to this model field. " + "Either remove the extra fields, or override " + "`Meta.validators` with a `UniqueTogetherValidator` " + "using the desired field names." + .format( + model=self.Meta.model.__name__, + serializer=self.__class__.__name__, + field=source, + fields=', '.join(source_map[source]), ) - - field_names = tuple(source_map[f][0] for f in unique_together) - validator = UniqueTogetherValidator( - queryset=parent_class._default_manager, - fields=field_names ) - validators.append(validator) + + field_names = tuple(source_map[f][0] for f in unique_together) + validator = UniqueTogetherValidator( + queryset=queryset, + fields=field_names + ) + validators.append(validator) return validators def get_unique_for_date_validators(self): diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 7e8e8f046..fc63f96fe 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -62,6 +62,29 @@ def get_detail_view_name(model): } +def get_unique_validators(field_name, model_field): + """ + Returns a list of UniqueValidators that should be applied to the field. + """ + field_set = set([field_name]) + conditions = { + c.condition + for c in model_field.model._meta.constraints + if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set + } + if getattr(model_field, 'unique', False): + conditions.add(None) + if not conditions: + return + unique_error_message = get_unique_error_message(model_field) + queryset = model_field.model._default_manager + for condition in conditions: + yield UniqueValidator( + queryset=queryset if condition is None else queryset.filter(condition), + message=unique_error_message + ) + + def get_field_kwargs(field_name, model_field): """ Creates a default instance of a basic non-relational field. @@ -216,11 +239,7 @@ def get_field_kwargs(field_name, model_field): if not isinstance(validator, validators.MinLengthValidator) ] - if getattr(model_field, 'unique', False): - validator = UniqueValidator( - queryset=model_field.model._default_manager, - message=get_unique_error_message(model_field)) - validator_kwarg.append(validator) + validator_kwarg += get_unique_validators(field_name, model_field) if validator_kwarg: kwargs['validators'] = validator_kwarg diff --git a/setup.py b/setup.py index d00470268..9a5b272f3 100755 --- a/setup.py +++ b/setup.py @@ -94,6 +94,7 @@ setup( 'Framework :: Django :: 3.2', 'Framework :: Django :: 4.0', 'Framework :: Django :: 4.1', + 'Framework :: Django :: 4.2', 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', diff --git a/tests/test_fields.py b/tests/test_fields.py index 512f3f789..56e2a45ba 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2,6 +2,7 @@ import datetime import math import os import re +import sys import uuid from decimal import ROUND_DOWN, ROUND_UP, Decimal @@ -625,6 +626,15 @@ class Test5087Regression: assert field.root is parent +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_field_is_subscriptable(self): + assert serializers.Field is serializers.Field["foo"] + + # Tests for field input and output values. # ---------------------------------------- diff --git a/tests/test_generics.py b/tests/test_generics.py index 78dc5afb6..9990389c9 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,3 +1,5 @@ +import sys + import pytest from django.db import models from django.http import Http404 @@ -698,3 +700,26 @@ class TestSerializer(TestCase): serializer = response.serializer assert serializer.context is context + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_genericview_is_subscriptable(self): + assert generics.GenericAPIView is generics.GenericAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_listview_is_subscriptable(self): + assert generics.ListAPIView is generics.ListAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_instanceview_is_subscriptable(self): + assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"] diff --git a/tests/test_request.py b/tests/test_request.py index 8c18aea9e..e37aa7dda 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -3,6 +3,7 @@ Tests for content parsing, and form-overloaded content parsing. """ import copy import os.path +import sys import tempfile import pytest @@ -352,3 +353,12 @@ class TestDeepcopy(TestCase): def test_deepcopy_works(self): request = Request(factory.get('/', secure=False)) copy.deepcopy(request) + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_request_is_subscriptable(self): + assert Request is Request["foo"] diff --git a/tests/test_response.py b/tests/test_response.py index 0d5528dc9..cab19a1eb 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,3 +1,6 @@ +import sys + +import pytest from django.test import TestCase, override_settings from django.urls import include, path, re_path @@ -283,3 +286,12 @@ class Issue807Tests(TestCase): self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') # self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text description.') + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_response_is_subscriptable(self): + assert Response is Response["foo"] diff --git a/tests/test_reverse.py b/tests/test_reverse.py index b26b448c9..b89f5be43 100644 --- a/tests/test_reverse.py +++ b/tests/test_reverse.py @@ -3,6 +3,7 @@ from django.urls import NoReverseMatch, path from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory +from rest_framework.versioning import BaseVersioning factory = APIRequestFactory() @@ -16,7 +17,7 @@ urlpatterns = [ ] -class MockVersioningScheme: +class MockVersioningScheme(BaseVersioning): def __init__(self, raise_error=False): self.raise_error = raise_error diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py index 341c4342a..7b8b3190f 100644 --- a/tests/test_validation_error.py +++ b/tests/test_validation_error.py @@ -109,3 +109,89 @@ class TestValidationErrorConvertsTuplesToLists(TestCase): assert len(error.detail) == 2 assert str(error.detail[0]) == 'message1' assert str(error.detail[1]) == 'message2' + + +class TestValidationErrorWithDjangoStyle(TestCase): + def test_validation_error_details(self): + error = ValidationError('Invalid value: %(value)s', params={'value': '42'}) + assert str(error.detail[0]) == 'Invalid value: 42' + + def test_validation_error_details_tuple(self): + error = ValidationError( + detail=('Invalid value: %(value1)s', 'Invalid value: %(value2)s'), + params={'value1': '42', 'value2': '43'}, + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 2 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + + def test_validation_error_details_list(self): + error = ValidationError( + detail=['Invalid value: %(value1)s', 'Invalid value: %(value2)s', ], + params={'value1': '42', 'value2': '43'} + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 2 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + + def test_validation_error_details_validation_errors(self): + error = ValidationError( + detail=ValidationError( + detail='Invalid value: %(value1)s', + params={'value1': '42'}, + ), + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 1 + assert str(error.detail[0]) == 'Invalid value: 42' + + def test_validation_error_details_validation_errors_list(self): + error = ValidationError( + detail=[ + ValidationError( + detail='Invalid value: %(value1)s', + params={'value1': '42'}, + ), + ValidationError( + detail='Invalid value: %(value2)s', + params={'value2': '43'}, + ), + 'Invalid value: %(value3)s' + ], + params={'value3': '44'} + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 3 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + assert str(error.detail[2]) == 'Invalid value: 44' + + def test_validation_error_details_validation_errors_nested_list(self): + error = ValidationError( + detail=[ + ValidationError( + detail='Invalid value: %(value1)s', + params={'value1': '42'}, + ), + ValidationError( + detail=[ + 'Invalid value: %(value2)s', + ValidationError( + detail='Invalid value: %(value3)s', + params={'value3': '44'}, + ) + ], + params={'value2': '43'}, + ), + 'Invalid value: %(value4)s' + ], + params={'value4': '45'} + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 4 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + assert str(error.detail[2]) == 'Invalid value: 44' + assert str(error.detail[3]) == 'Invalid value: 45' diff --git a/tests/test_validators.py b/tests/test_validators.py index 39490ac86..35fef6f26 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -464,6 +464,106 @@ class TestUniquenessTogetherValidation(TestCase): assert queryset.called_with == {'race_name': 'bar', 'position': 1} +class UniqueConstraintModel(models.Model): + race_name = models.CharField(max_length=100) + position = models.IntegerField() + global_id = models.IntegerField() + fancy_conditions = models.IntegerField(null=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + name="unique_constraint_model_global_id_uniq", + fields=('global_id',), + ), + models.UniqueConstraint( + name="unique_constraint_model_fancy_1_uniq", + fields=('fancy_conditions',), + condition=models.Q(global_id__lte=1) + ), + models.UniqueConstraint( + name="unique_constraint_model_fancy_3_uniq", + fields=('fancy_conditions',), + condition=models.Q(global_id__gte=3) + ), + models.UniqueConstraint( + name="unique_constraint_model_together_uniq", + fields=('race_name', 'position'), + condition=models.Q(race_name='example'), + ) + ] + + +class UniqueConstraintSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueConstraintModel + fields = '__all__' + + +class TestUniqueConstraintValidation(TestCase): + def setUp(self): + self.instance = UniqueConstraintModel.objects.create( + race_name='example', + position=1, + global_id=1 + ) + UniqueConstraintModel.objects.create( + race_name='example', + position=2, + global_id=2 + ) + UniqueConstraintModel.objects.create( + race_name='other', + position=1, + global_id=3 + ) + + def test_repr(self): + serializer = UniqueConstraintSerializer() + # the order of validators isn't deterministic so delete + # fancy_conditions field that has two of them + del serializer.fields['fancy_conditions'] + expected = dedent(""" + UniqueConstraintSerializer(): + id = IntegerField(label='ID', read_only=True) + race_name = CharField(max_length=100, required=True) + position = IntegerField(required=True) + global_id = IntegerField(validators=[]) + class Meta: + validators = [, ]>, fields=('race_name', 'position'))>] + """) + assert repr(serializer) == expected + + def test_unique_together_field(self): + """ + UniqueConstraint fields and condition attributes must be passed + to UniqueTogetherValidator as fields and queryset + """ + serializer = UniqueConstraintSerializer() + assert len(serializer.validators) == 1 + validator = serializer.validators[0] + assert validator.fields == ('race_name', 'position') + assert set(validator.queryset.values_list(flat=True)) == set( + UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True) + ) + + def test_single_field_uniq_validators(self): + """ + UniqueConstraint with single field must be transformed into + field's UniqueValidator + """ + serializer = UniqueConstraintSerializer() + assert len(serializer.validators) == 1 + validators = serializer.fields['global_id'].validators + assert len(validators) == 1 + assert validators[0].queryset == UniqueConstraintModel.objects + + validators = serializer.fields['fancy_conditions'].validators + assert len(validators) == 2 + ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators} + assert ids_in_qs == {frozenset([1]), frozenset([3])} + + # Tests for `UniqueForDateValidator` # ---------------------------------- diff --git a/tests/test_versioning.py b/tests/test_versioning.py index d40d54229..93f61d2be 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,6 +1,6 @@ import pytest from django.test import override_settings -from django.urls import include, path, re_path +from django.urls import ResolverMatch, include, path, re_path from rest_framework import serializers, status, versioning from rest_framework.decorators import APIView @@ -126,7 +126,7 @@ class TestRequestVersion: assert response.data == {'version': None} def test_namespace_versioning(self): - class FakeResolverMatch: + class FakeResolverMatch(ResolverMatch): namespace = 'v1' scheme = versioning.NamespaceVersioning @@ -199,7 +199,7 @@ class TestURLReversing(URLPatternsTestCase, APITestCase): assert response.data == {'url': 'http://testserver/another/'} def test_reverse_namespace_versioning(self): - class FakeResolverMatch: + class FakeResolverMatch(ResolverMatch): namespace = 'v1' scheme = versioning.NamespaceVersioning @@ -250,7 +250,7 @@ class TestInvalidVersion: assert response.status_code == status.HTTP_404_NOT_FOUND def test_invalid_namespace_versioning(self): - class FakeResolverMatch: + class FakeResolverMatch(ResolverMatch): namespace = 'v3' scheme = versioning.NamespaceVersioning diff --git a/tox.ini b/tox.ini index 05cdd25dc..7027612e0 100644 --- a/tox.ini +++ b/tox.ini @@ -3,8 +3,8 @@ envlist = {py36,py37,py38,py39}-django30 {py36,py37,py38,py39}-django31 {py36,py37,py38,py39,py310}-django32 - {py38,py39,py310}-{django40,django41,djangomain} - {py311}-{django41,djangomain} + {py38,py39,py310}-{django40,django41,django42,djangomain} + {py311}-{django41,django42,djangomain} base dist docs @@ -21,6 +21,7 @@ deps = django32: Django>=3.2,<4.0 django40: Django>=4.0,<4.1 django41: Django>=4.1,<4.2 + django42: Django>=4.2b1,<5.0 djangomain: https://github.com/django/django/archive/main.tar.gz -rrequirements/requirements-testing.txt -rrequirements/requirements-optionals.txt