From 2db0c0bf0a97ce42369d5b3474d38cebd274d8ae Mon Sep 17 00:00:00 2001 From: Asif Saif Uddin Date: Thu, 19 Jan 2023 20:47:50 +0600 Subject: [PATCH 1/9] initial django 4.2a1 testing (#8846) * initial django 4.2a1 testing * django 4.2 in classifier --- setup.py | 1 + tox.ini | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index d00470268..9a5b272f3 100755 --- a/setup.py +++ b/setup.py @@ -94,6 +94,7 @@ setup( 'Framework :: Django :: 3.2', 'Framework :: Django :: 4.0', 'Framework :: Django :: 4.1', + 'Framework :: Django :: 4.2', 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', diff --git a/tox.ini b/tox.ini index 05cdd25dc..6fd9d8695 100644 --- a/tox.ini +++ b/tox.ini @@ -3,8 +3,8 @@ envlist = {py36,py37,py38,py39}-django30 {py36,py37,py38,py39}-django31 {py36,py37,py38,py39,py310}-django32 - {py38,py39,py310}-{django40,django41,djangomain} - {py311}-{django41,djangomain} + {py38,py39,py310}-{django40,django41,django42,djangomain} + {py311}-{django41,django42,djangomain} base dist docs @@ -21,6 +21,7 @@ deps = django32: Django>=3.2,<4.0 django40: Django>=4.0,<4.1 django41: Django>=4.1,<4.2 + django42: Django>=4.2a1,<5.0 djangomain: https://github.com/django/django/archive/main.tar.gz -rrequirements/requirements-testing.txt -rrequirements/requirements-optionals.txt From 22d206c1e0dbc03840c4d190f7eda537c0f2010a Mon Sep 17 00:00:00 2001 From: piotrszyma Date: Sat, 28 Jan 2023 12:18:58 +0100 Subject: [PATCH 2/9] Inherit from faked classes in tests to satisfy mypy (#8859) * tests: inherit FakeResolverMatcher from django.urls.ResolverMatcher in tests/test_versioning.py * tests: inherit from rest_framework.versioning.BaseVersioning in tests/test_reverse.py * fix: isort --------- Co-authored-by: Piotr Szyma --- tests/test_reverse.py | 3 ++- tests/test_versioning.py | 8 ++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/test_reverse.py b/tests/test_reverse.py index b26b448c9..b89f5be43 100644 --- a/tests/test_reverse.py +++ b/tests/test_reverse.py @@ -3,6 +3,7 @@ from django.urls import NoReverseMatch, path from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory +from rest_framework.versioning import BaseVersioning factory = APIRequestFactory() @@ -16,7 +17,7 @@ urlpatterns = [ ] -class MockVersioningScheme: +class MockVersioningScheme(BaseVersioning): def __init__(self, raise_error=False): self.raise_error = raise_error diff --git a/tests/test_versioning.py b/tests/test_versioning.py index d40d54229..93f61d2be 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,6 +1,6 @@ import pytest from django.test import override_settings -from django.urls import include, path, re_path +from django.urls import ResolverMatch, include, path, re_path from rest_framework import serializers, status, versioning from rest_framework.decorators import APIView @@ -126,7 +126,7 @@ class TestRequestVersion: assert response.data == {'version': None} def test_namespace_versioning(self): - class FakeResolverMatch: + class FakeResolverMatch(ResolverMatch): namespace = 'v1' scheme = versioning.NamespaceVersioning @@ -199,7 +199,7 @@ class TestURLReversing(URLPatternsTestCase, APITestCase): assert response.data == {'url': 'http://testserver/another/'} def test_reverse_namespace_versioning(self): - class FakeResolverMatch: + class FakeResolverMatch(ResolverMatch): namespace = 'v1' scheme = versioning.NamespaceVersioning @@ -250,7 +250,7 @@ class TestInvalidVersion: assert response.status_code == status.HTTP_404_NOT_FOUND def test_invalid_namespace_versioning(self): - class FakeResolverMatch: + class FakeResolverMatch(ResolverMatch): namespace = 'v3' scheme = versioning.NamespaceVersioning From 4abfa28e0879e2df45937ac8c7a9ffa161561955 Mon Sep 17 00:00:00 2001 From: Ehsan200 <59165228+Ehsan200@users.noreply.github.com> Date: Mon, 6 Feb 2023 15:00:48 +0330 Subject: [PATCH 3/9] feat: Add some changes to ValidationError to support django style vadation errors (#8863) --- rest_framework/exceptions.py | 19 ++++++-- tests/test_validation_error.py | 86 ++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 3 deletions(-) diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 09f111102..bc20fcaa3 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -144,17 +144,30 @@ class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST default_detail = _('Invalid input.') default_code = 'invalid' + default_params = {} - def __init__(self, detail=None, code=None): + def __init__(self, detail=None, code=None, params=None): if detail is None: detail = self.default_detail if code is None: code = self.default_code + if params is None: + params = self.default_params # For validation failures, we may collect many errors together, # so the details should always be coerced to a list if not already. - if isinstance(detail, tuple): - detail = list(detail) + if isinstance(detail, str): + detail = [detail % params] + elif isinstance(detail, ValidationError): + detail = detail.detail + elif isinstance(detail, (list, tuple)): + final_detail = [] + for detail_item in detail: + if isinstance(detail_item, ValidationError): + final_detail += detail_item.detail + else: + final_detail += [detail_item % params if isinstance(detail_item, str) else detail_item] + detail = final_detail elif not isinstance(detail, dict) and not isinstance(detail, list): detail = [detail] diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py index 341c4342a..7b8b3190f 100644 --- a/tests/test_validation_error.py +++ b/tests/test_validation_error.py @@ -109,3 +109,89 @@ class TestValidationErrorConvertsTuplesToLists(TestCase): assert len(error.detail) == 2 assert str(error.detail[0]) == 'message1' assert str(error.detail[1]) == 'message2' + + +class TestValidationErrorWithDjangoStyle(TestCase): + def test_validation_error_details(self): + error = ValidationError('Invalid value: %(value)s', params={'value': '42'}) + assert str(error.detail[0]) == 'Invalid value: 42' + + def test_validation_error_details_tuple(self): + error = ValidationError( + detail=('Invalid value: %(value1)s', 'Invalid value: %(value2)s'), + params={'value1': '42', 'value2': '43'}, + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 2 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + + def test_validation_error_details_list(self): + error = ValidationError( + detail=['Invalid value: %(value1)s', 'Invalid value: %(value2)s', ], + params={'value1': '42', 'value2': '43'} + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 2 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + + def test_validation_error_details_validation_errors(self): + error = ValidationError( + detail=ValidationError( + detail='Invalid value: %(value1)s', + params={'value1': '42'}, + ), + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 1 + assert str(error.detail[0]) == 'Invalid value: 42' + + def test_validation_error_details_validation_errors_list(self): + error = ValidationError( + detail=[ + ValidationError( + detail='Invalid value: %(value1)s', + params={'value1': '42'}, + ), + ValidationError( + detail='Invalid value: %(value2)s', + params={'value2': '43'}, + ), + 'Invalid value: %(value3)s' + ], + params={'value3': '44'} + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 3 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + assert str(error.detail[2]) == 'Invalid value: 44' + + def test_validation_error_details_validation_errors_nested_list(self): + error = ValidationError( + detail=[ + ValidationError( + detail='Invalid value: %(value1)s', + params={'value1': '42'}, + ), + ValidationError( + detail=[ + 'Invalid value: %(value2)s', + ValidationError( + detail='Invalid value: %(value3)s', + params={'value3': '44'}, + ) + ], + params={'value2': '43'}, + ), + 'Invalid value: %(value4)s' + ], + params={'value4': '45'} + ) + assert isinstance(error.detail, list) + assert len(error.detail) == 4 + assert str(error.detail[0]) == 'Invalid value: 42' + assert str(error.detail[1]) == 'Invalid value: 43' + assert str(error.detail[2]) == 'Invalid value: 44' + assert str(error.detail[3]) == 'Invalid value: 45' From 34953774f34e2dc980c522f32000587ea7edb6a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=96zg=C3=BCr?= Date: Thu, 16 Feb 2023 01:48:34 -0500 Subject: [PATCH 4/9] docs: fix code example (#8880) --- docs/api-guide/serializers.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md index 0d14cac46..2b717b6e6 100644 --- a/docs/api-guide/serializers.md +++ b/docs/api-guide/serializers.md @@ -226,7 +226,7 @@ Individual fields on a serializer can include validators, by declaring them on t raise serializers.ValidationError('Not a multiple of ten') class GameRecord(serializers.Serializer): - score = IntegerField(validators=[multiple_of_ten]) + score = serializers.IntegerField(validators=[multiple_of_ten]) ... Serializer classes can also include reusable validators that are applied to the complete set of field data. These validators are included by declaring them on an inner `Meta` class, like so: From 390daf7a92bdf19cb8e08e1cc933c33c069da357 Mon Sep 17 00:00:00 2001 From: Jameel Al-Aziz <247849+jalaziz@users.noreply.github.com> Date: Tue, 21 Feb 2023 22:05:45 -0800 Subject: [PATCH 5/9] Upgrade isort version in pre-commit (#8882) This fixes recent issues with installing isort via pre-commit that was introduced in recent versions of poetry-core. See https://github.com/PyCQA/isort/pull/2078 --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5a6e554b9..83fe0b714 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - id: check-symlinks - id: check-toml - repo: https://github.com/pycqa/isort - rev: 5.8.0 + rev: 5.12.0 hooks: - id: isort - repo: https://github.com/PyCQA/flake8 From 15c613a9eb645c63102b9e894199bcf1c9bf4d65 Mon Sep 17 00:00:00 2001 From: Jameel Al-Aziz <247849+jalaziz@users.noreply.github.com> Date: Wed, 22 Feb 2023 07:39:01 -0800 Subject: [PATCH 6/9] Allow generic requests, responses, fields, views (#8825) Allow Request, Response, Field, and GenericAPIView to be subscriptable. This allows the classes to be made generic for type checking. This is especially useful since monkey patching DRF can be problematic as seen in this [issue][1]. [1]: https://github.com/typeddjango/djangorestframework-stubs/issues/299 --- rest_framework/fields.py | 4 ++++ rest_framework/generics.py | 4 ++++ rest_framework/request.py | 4 ++++ rest_framework/response.py | 4 ++++ tests/test_fields.py | 10 ++++++++++ tests/test_generics.py | 25 +++++++++++++++++++++++++ tests/test_request.py | 10 ++++++++++ tests/test_response.py | 12 ++++++++++++ 8 files changed, 73 insertions(+) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 1c6425596..613bd325a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -356,6 +356,10 @@ class Field: messages.update(error_messages or {}) self.error_messages = messages + # Allow generic typing checking for fields. + def __class_getitem__(cls, *args, **kwargs): + return cls + def bind(self, field_name, parent): """ Initializes the field name and parent for the field instance. diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55cfafda4..167303321 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -45,6 +45,10 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + # Allow generic typing checking for generic views. + def __class_getitem__(cls, *args, **kwargs): + return cls + def get_queryset(self): """ Get the list of items for this view. diff --git a/rest_framework/request.py b/rest_framework/request.py index 194be5f6d..93109226d 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -186,6 +186,10 @@ class Request: self.method, self.get_full_path()) + # Allow generic typing checking for requests. + def __class_getitem__(cls, *args, **kwargs): + return cls + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/response.py b/rest_framework/response.py index 495423734..6e756544c 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -46,6 +46,10 @@ class Response(SimpleTemplateResponse): for name, value in headers.items(): self[name] = value + # Allow generic typing checking for responses. + def __class_getitem__(cls, *args, **kwargs): + return cls + @property def rendered_content(self): renderer = getattr(self, 'accepted_renderer', None) diff --git a/tests/test_fields.py b/tests/test_fields.py index 512f3f789..56e2a45ba 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -2,6 +2,7 @@ import datetime import math import os import re +import sys import uuid from decimal import ROUND_DOWN, ROUND_UP, Decimal @@ -625,6 +626,15 @@ class Test5087Regression: assert field.root is parent +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_field_is_subscriptable(self): + assert serializers.Field is serializers.Field["foo"] + + # Tests for field input and output values. # ---------------------------------------- diff --git a/tests/test_generics.py b/tests/test_generics.py index 78dc5afb6..9990389c9 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,3 +1,5 @@ +import sys + import pytest from django.db import models from django.http import Http404 @@ -698,3 +700,26 @@ class TestSerializer(TestCase): serializer = response.serializer assert serializer.context is context + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_genericview_is_subscriptable(self): + assert generics.GenericAPIView is generics.GenericAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_listview_is_subscriptable(self): + assert generics.ListAPIView is generics.ListAPIView["foo"] + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_instanceview_is_subscriptable(self): + assert generics.RetrieveAPIView is generics.RetrieveAPIView["foo"] diff --git a/tests/test_request.py b/tests/test_request.py index 8c18aea9e..e37aa7dda 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -3,6 +3,7 @@ Tests for content parsing, and form-overloaded content parsing. """ import copy import os.path +import sys import tempfile import pytest @@ -352,3 +353,12 @@ class TestDeepcopy(TestCase): def test_deepcopy_works(self): request = Request(factory.get('/', secure=False)) copy.deepcopy(request) + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_request_is_subscriptable(self): + assert Request is Request["foo"] diff --git a/tests/test_response.py b/tests/test_response.py index 0d5528dc9..cab19a1eb 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,3 +1,6 @@ +import sys + +import pytest from django.test import TestCase, override_settings from django.urls import include, path, re_path @@ -283,3 +286,12 @@ class Issue807Tests(TestCase): self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') # self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text description.') + + +class TestTyping(TestCase): + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="subscriptable classes requires Python 3.7 or higher", + ) + def test_response_is_subscriptable(self): + assert Response is Response["foo"] From 3f8ab538c1a7e6f887af9fec41847e2d67ff674f Mon Sep 17 00:00:00 2001 From: Jayant <39442192+RoguedBear@users.noreply.github.com> Date: Sun, 26 Feb 2023 18:06:48 +0000 Subject: [PATCH 7/9] docs: add missing renderer import in tutorial 6 (#8885) --- docs/tutorial/6-viewsets-and-routers.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tutorial/6-viewsets-and-routers.md b/docs/tutorial/6-viewsets-and-routers.md index 74789e337..f9b6c5e9a 100644 --- a/docs/tutorial/6-viewsets-and-routers.md +++ b/docs/tutorial/6-viewsets-and-routers.md @@ -27,6 +27,7 @@ Here we've used the `ReadOnlyModelViewSet` class to automatically provide the de Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class. from rest_framework import permissions + from rest_framework import renderers from rest_framework.decorators import action from rest_framework.response import Response From 9882207c160802e5a3655a1806a04f3357f11719 Mon Sep 17 00:00:00 2001 From: Asif Saif Uddin Date: Tue, 28 Feb 2023 22:02:12 +0600 Subject: [PATCH 8/9] test django 4.2b1 (#8892) --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 6fd9d8695..7027612e0 100644 --- a/tox.ini +++ b/tox.ini @@ -21,7 +21,7 @@ deps = django32: Django>=3.2,<4.0 django40: Django>=4.0,<4.1 django41: Django>=4.1,<4.2 - django42: Django>=4.2a1,<5.0 + django42: Django>=4.2b1,<5.0 djangomain: https://github.com/django/django/archive/main.tar.gz -rrequirements/requirements-testing.txt -rrequirements/requirements-optionals.txt From b7523f4b9f5ec354cd50bd514784e6248be47a37 Mon Sep 17 00:00:00 2001 From: Konstantin Alekseev Date: Fri, 3 Mar 2023 09:04:47 +0200 Subject: [PATCH 9/9] Support UniqueConstraint (#7438) --- rest_framework/serializers.py | 80 ++++++++++++--------- rest_framework/utils/field_mapping.py | 29 ++++++-- tests/test_validators.py | 100 ++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 40 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index eae6a0b2e..e27f8a47c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -1398,6 +1398,23 @@ class ModelSerializer(Serializer): return extra_kwargs + def get_unique_together_constraints(self, model): + """ + Returns iterator of (fields, queryset), each entry describes an unique together + constraint on `fields` in `queryset`. + """ + for parent_class in [model] + list(model._meta.parents): + for unique_together in parent_class._meta.unique_together: + yield unique_together, model._default_manager + for constraint in parent_class._meta.constraints: + if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1: + yield ( + constraint.fields, + model._default_manager + if constraint.condition is None + else model._default_manager.filter(constraint.condition) + ) + def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): """ Return any additional field options that need to be included as a @@ -1426,12 +1443,11 @@ class ModelSerializer(Serializer): unique_constraint_names -= {None} - # Include each of the `unique_together` field names, + # Include each of the `unique_together` and `UniqueConstraint` field names, # so long as all the field names are included on the serializer. - for parent_class in [model] + list(model._meta.parents): - for unique_together_list in parent_class._meta.unique_together: - if set(field_names).issuperset(unique_together_list): - unique_constraint_names |= set(unique_together_list) + for unique_together_list, queryset in self.get_unique_together_constraints(model): + if set(field_names).issuperset(unique_together_list): + unique_constraint_names |= set(unique_together_list) # Now we have all the field names that have uniqueness constraints # applied, we can add the extra 'required=...' or 'default=...' @@ -1526,11 +1542,6 @@ class ModelSerializer(Serializer): """ Determine a default set of validators for any unique_together constraints. """ - model_class_inheritance_tree = ( - [self.Meta.model] + - list(self.Meta.model._meta.parents) - ) - # The field names we're passing though here only include fields # which may map onto a model field. Any dotted field name lookups # cannot map to a field, and must be a traversal, so we're not @@ -1556,34 +1567,33 @@ class ModelSerializer(Serializer): # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] - for parent_class in model_class_inheritance_tree: - for unique_together in parent_class._meta.unique_together: - # Skip if serializer does not map to all unique together sources - if not set(source_map).issuperset(unique_together): - continue + for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model): + # Skip if serializer does not map to all unique together sources + if not set(source_map).issuperset(unique_together): + continue - for source in unique_together: - assert len(source_map[source]) == 1, ( - "Unable to create `UniqueTogetherValidator` for " - "`{model}.{field}` as `{serializer}` has multiple " - "fields ({fields}) that map to this model field. " - "Either remove the extra fields, or override " - "`Meta.validators` with a `UniqueTogetherValidator` " - "using the desired field names." - .format( - model=self.Meta.model.__name__, - serializer=self.__class__.__name__, - field=source, - fields=', '.join(source_map[source]), - ) + for source in unique_together: + assert len(source_map[source]) == 1, ( + "Unable to create `UniqueTogetherValidator` for " + "`{model}.{field}` as `{serializer}` has multiple " + "fields ({fields}) that map to this model field. " + "Either remove the extra fields, or override " + "`Meta.validators` with a `UniqueTogetherValidator` " + "using the desired field names." + .format( + model=self.Meta.model.__name__, + serializer=self.__class__.__name__, + field=source, + fields=', '.join(source_map[source]), ) - - field_names = tuple(source_map[f][0] for f in unique_together) - validator = UniqueTogetherValidator( - queryset=parent_class._default_manager, - fields=field_names ) - validators.append(validator) + + field_names = tuple(source_map[f][0] for f in unique_together) + validator = UniqueTogetherValidator( + queryset=queryset, + fields=field_names + ) + validators.append(validator) return validators def get_unique_for_date_validators(self): diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 7e8e8f046..fc63f96fe 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -62,6 +62,29 @@ def get_detail_view_name(model): } +def get_unique_validators(field_name, model_field): + """ + Returns a list of UniqueValidators that should be applied to the field. + """ + field_set = set([field_name]) + conditions = { + c.condition + for c in model_field.model._meta.constraints + if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set + } + if getattr(model_field, 'unique', False): + conditions.add(None) + if not conditions: + return + unique_error_message = get_unique_error_message(model_field) + queryset = model_field.model._default_manager + for condition in conditions: + yield UniqueValidator( + queryset=queryset if condition is None else queryset.filter(condition), + message=unique_error_message + ) + + def get_field_kwargs(field_name, model_field): """ Creates a default instance of a basic non-relational field. @@ -216,11 +239,7 @@ def get_field_kwargs(field_name, model_field): if not isinstance(validator, validators.MinLengthValidator) ] - if getattr(model_field, 'unique', False): - validator = UniqueValidator( - queryset=model_field.model._default_manager, - message=get_unique_error_message(model_field)) - validator_kwarg.append(validator) + validator_kwarg += get_unique_validators(field_name, model_field) if validator_kwarg: kwargs['validators'] = validator_kwarg diff --git a/tests/test_validators.py b/tests/test_validators.py index 39490ac86..35fef6f26 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -464,6 +464,106 @@ class TestUniquenessTogetherValidation(TestCase): assert queryset.called_with == {'race_name': 'bar', 'position': 1} +class UniqueConstraintModel(models.Model): + race_name = models.CharField(max_length=100) + position = models.IntegerField() + global_id = models.IntegerField() + fancy_conditions = models.IntegerField(null=True) + + class Meta: + constraints = [ + models.UniqueConstraint( + name="unique_constraint_model_global_id_uniq", + fields=('global_id',), + ), + models.UniqueConstraint( + name="unique_constraint_model_fancy_1_uniq", + fields=('fancy_conditions',), + condition=models.Q(global_id__lte=1) + ), + models.UniqueConstraint( + name="unique_constraint_model_fancy_3_uniq", + fields=('fancy_conditions',), + condition=models.Q(global_id__gte=3) + ), + models.UniqueConstraint( + name="unique_constraint_model_together_uniq", + fields=('race_name', 'position'), + condition=models.Q(race_name='example'), + ) + ] + + +class UniqueConstraintSerializer(serializers.ModelSerializer): + class Meta: + model = UniqueConstraintModel + fields = '__all__' + + +class TestUniqueConstraintValidation(TestCase): + def setUp(self): + self.instance = UniqueConstraintModel.objects.create( + race_name='example', + position=1, + global_id=1 + ) + UniqueConstraintModel.objects.create( + race_name='example', + position=2, + global_id=2 + ) + UniqueConstraintModel.objects.create( + race_name='other', + position=1, + global_id=3 + ) + + def test_repr(self): + serializer = UniqueConstraintSerializer() + # the order of validators isn't deterministic so delete + # fancy_conditions field that has two of them + del serializer.fields['fancy_conditions'] + expected = dedent(""" + UniqueConstraintSerializer(): + id = IntegerField(label='ID', read_only=True) + race_name = CharField(max_length=100, required=True) + position = IntegerField(required=True) + global_id = IntegerField(validators=[]) + class Meta: + validators = [, ]>, fields=('race_name', 'position'))>] + """) + assert repr(serializer) == expected + + def test_unique_together_field(self): + """ + UniqueConstraint fields and condition attributes must be passed + to UniqueTogetherValidator as fields and queryset + """ + serializer = UniqueConstraintSerializer() + assert len(serializer.validators) == 1 + validator = serializer.validators[0] + assert validator.fields == ('race_name', 'position') + assert set(validator.queryset.values_list(flat=True)) == set( + UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True) + ) + + def test_single_field_uniq_validators(self): + """ + UniqueConstraint with single field must be transformed into + field's UniqueValidator + """ + serializer = UniqueConstraintSerializer() + assert len(serializer.validators) == 1 + validators = serializer.fields['global_id'].validators + assert len(validators) == 1 + assert validators[0].queryset == UniqueConstraintModel.objects + + validators = serializer.fields['fancy_conditions'].validators + assert len(validators) == 2 + ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators} + assert ids_in_qs == {frozenset([1]), frozenset([3])} + + # Tests for `UniqueForDateValidator` # ----------------------------------