diff --git a/.travis.yml b/.travis.yml index f1ec689f7..39efaf4fc 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,13 +9,16 @@ matrix: - { python: "3.6", env: DJANGO=2.2 } - { python: "3.6", env: DJANGO=3.0 } + - { python: "3.6", env: DJANGO=3.1 } - { python: "3.6", env: DJANGO=master } - { python: "3.7", env: DJANGO=2.2 } - { python: "3.7", env: DJANGO=3.0 } + - { python: "3.7", env: DJANGO=3.1 } - { python: "3.7", env: DJANGO=master } - { python: "3.8", env: DJANGO=3.0 } + - { python: "3.8", env: DJANGO=3.1 } - { python: "3.8", env: DJANGO=master } - { python: "3.8", env: TOXENV=base } diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index 65c83b78e..b2bdd50c8 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -603,7 +603,7 @@ The `to_internal_value()` method is called to restore a primitive datatype into Let's look at an example of serializing a class that represents an RGB color value: - class Color(object): + class Color: """ A color represented in the RGB colorspace. """ diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md index a256eb2d9..4ff549f07 100644 --- a/docs/api-guide/generic-views.md +++ b/docs/api-guide/generic-views.md @@ -319,7 +319,7 @@ Often you'll want to use the existing generic views, but use some slightly custo For example, if you need to lookup objects based on multiple fields in the URL conf, you could create a mixin class like the following: - class MultipleFieldLookupMixin(object): + class MultipleFieldLookupMixin: """ Apply this mixin to any view or viewset to get multiple field filtering based on a `lookup_fields` attribute, instead of the default single field filtering. diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md index 87d3d4056..4f566ff59 100644 --- a/docs/api-guide/serializers.md +++ b/docs/api-guide/serializers.md @@ -21,7 +21,7 @@ Let's start by creating a simple object we can use for example purposes: from datetime import datetime - class Comment(object): + class Comment: def __init__(self, email, content, created=None): self.email = email self.content = content diff --git a/docs/api-guide/validators.md b/docs/api-guide/validators.md index 009cd2468..4451489d4 100644 --- a/docs/api-guide/validators.md +++ b/docs/api-guide/validators.md @@ -282,7 +282,7 @@ to your `Serializer` subclass. This is documented in the To write a class-based validator, use the `__call__` method. Class-based validators are useful as they allow you to parameterize and reuse behavior. - class MultipleOf(object): + class MultipleOf: def __init__(self, base): self.base = base diff --git a/rest_framework/compat.py b/rest_framework/compat.py index df100966b..611068a62 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -2,75 +2,9 @@ The `compat` module provides support for backwards compatibility with older versions of Django/Python, and compatibility wrappers around optional packages. """ -import sys - from django.conf import settings from django.views.generic import View -try: - from django.urls import ( # noqa - URLPattern, - URLResolver, - ) -except ImportError: - # Will be removed in Django 2.0 - from django.urls import ( # noqa - RegexURLPattern as URLPattern, - RegexURLResolver as URLResolver, - ) - -try: - from django.core.validators import ProhibitNullCharactersValidator # noqa -except ImportError: - ProhibitNullCharactersValidator = None - - -def get_original_route(urlpattern): - """ - Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This - is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path(). - """ - if hasattr(urlpattern, 'pattern'): - # Django 2.0 - return str(urlpattern.pattern) - else: - # Django < 2.0 - return urlpattern.regex.pattern - - -def get_regex_pattern(urlpattern): - """ - Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression, - unlike get_original_route above. - """ - if hasattr(urlpattern, 'pattern'): - # Django 2.0 - return urlpattern.pattern.regex.pattern - else: - # Django < 2.0 - return urlpattern.regex.pattern - - -def is_route_pattern(urlpattern): - if hasattr(urlpattern, 'pattern'): - # Django 2.0 - from django.urls.resolvers import RoutePattern - return isinstance(urlpattern.pattern, RoutePattern) - else: - # Django < 2.0 - return False - - -def make_url_resolver(regex, urlpatterns): - try: - # Django 2.0 - from django.urls.resolvers import RegexPattern - return URLResolver(RegexPattern(regex), urlpatterns) - - except ImportError: - # Django < 2.0 - return URLResolver(regex, urlpatterns) - def unicode_http_header(value): # Coerce HTTP header value to unicode. @@ -217,22 +151,8 @@ else: return False -# Django 1.x url routing syntax. Remove when dropping Django 1.11 support. -try: - from django.urls import include, path, re_path, register_converter # noqa -except ImportError: - from django.conf.urls import include, url # noqa - path = None - register_converter = None - re_path = url - - # `separators` argument to `json.dumps()` differs between 2.x and 3.x # See: https://bugs.python.org/issue22767 SHORT_SEPARATORS = (',', ':') LONG_SEPARATORS = (', ', ': ') INDENT_SEPARATORS = (',', ': ') - - -# Version Constants. -PY36 = sys.version_info >= (3, 6) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 958bebeef..da2dd54be 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -14,7 +14,8 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError as DjangoValidationError from django.core.validators import ( EmailValidator, MaxLengthValidator, MaxValueValidator, MinLengthValidator, - MinValueValidator, RegexValidator, URLValidator, ip_address_validators + MinValueValidator, ProhibitNullCharactersValidator, RegexValidator, + URLValidator, ip_address_validators ) from django.forms import FilePathField as DjangoFilePathField from django.forms import ImageField as DjangoImageField @@ -30,8 +31,9 @@ from django.utils.timezone import utc from django.utils.translation import gettext_lazy as _ from pytz.exceptions import InvalidTimeError -from rest_framework import ISO_8601, RemovedInDRF313Warning -from rest_framework.compat import ProhibitNullCharactersValidator +from rest_framework import ( + ISO_8601, RemovedInDRF313Warning, RemovedInDRF314Warning +) from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, json, representation @@ -740,55 +742,22 @@ class BooleanField(Field): return bool(value) -class NullBooleanField(Field): - default_error_messages = { - 'invalid': _('Must be a valid boolean.') - } +class NullBooleanField(BooleanField): initial = None - TRUE_VALUES = { - 't', 'T', - 'y', 'Y', 'yes', 'YES', - 'true', 'True', 'TRUE', - 'on', 'On', 'ON', - '1', 1, - True - } - FALSE_VALUES = { - 'f', 'F', - 'n', 'N', 'no', 'NO', - 'false', 'False', 'FALSE', - 'off', 'Off', 'OFF', - '0', 0, 0.0, - False - } - NULL_VALUES = {'null', 'Null', 'NULL', '', None} def __init__(self, **kwargs): + warnings.warn( + "The `NullBooleanField` is deprecated and will be removed starting " + "with 3.14. Instead use the `BooleanField` field and set " + "`null=True` which does the same thing.", + RemovedInDRF314Warning, stacklevel=2 + ) + assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.' kwargs['allow_null'] = True + super().__init__(**kwargs) - def to_internal_value(self, data): - try: - if data in self.TRUE_VALUES: - return True - elif data in self.FALSE_VALUES: - return False - elif data in self.NULL_VALUES: - return None - except TypeError: # Input is an unhashable type - pass - self.fail('invalid', input=data) - - def to_representation(self, value): - if value in self.NULL_VALUES: - return None - if value in self.TRUE_VALUES: - return True - elif value in self.FALSE_VALUES: - return False - return bool(value) - # String types... @@ -816,9 +785,7 @@ class CharField(Field): self.validators.append( MinLengthValidator(self.min_length, message=message)) - # ProhibitNullCharactersValidator is None on Django < 2.0 - if ProhibitNullCharactersValidator is not None: - self.validators.append(ProhibitNullCharactersValidator()) + self.validators.append(ProhibitNullCharactersValidator()) self.validators.append(ProhibitSurrogateCharactersValidator()) def run_validation(self, data=empty): diff --git a/rest_framework/schemas/generators.py b/rest_framework/schemas/generators.py index 4b6d82a14..d3c6446aa 100644 --- a/rest_framework/schemas/generators.py +++ b/rest_framework/schemas/generators.py @@ -10,9 +10,9 @@ from django.conf import settings from django.contrib.admindocs.views import simplify_regex from django.core.exceptions import PermissionDenied from django.http import Http404 +from django.urls import URLPattern, URLResolver from rest_framework import exceptions -from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.request import clone_request from rest_framework.settings import api_settings from rest_framework.utils.model_meta import _get_pk @@ -79,7 +79,7 @@ class EndpointEnumerator: api_endpoints = [] for pattern in patterns: - path_regex = prefix + get_original_route(pattern) + path_regex = prefix + str(pattern.pattern) if isinstance(pattern, URLPattern): path = self.get_path_from_regex(path_regex) callback = pattern.callback @@ -143,7 +143,7 @@ class EndpointEnumerator: return [method for method in methods if method not in ('OPTIONS', 'HEAD')] -class BaseSchemaGenerator(object): +class BaseSchemaGenerator: endpoint_inspector_cls = EndpointEnumerator # 'pk' isn't great as an externally exposed name for an identifier, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 8c2486bea..cfb54de13 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -13,7 +13,7 @@ response content is handled by parsers and renderers. import copy import inspect import traceback -from collections import OrderedDict +from collections import OrderedDict, defaultdict from collections.abc import Mapping from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured @@ -868,7 +868,7 @@ class ModelSerializer(Serializer): models.FloatField: FloatField, models.ImageField: ImageField, models.IntegerField: IntegerField, - models.NullBooleanField: NullBooleanField, + models.NullBooleanField: BooleanField, models.PositiveIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField, models.SlugField: SlugField, @@ -1508,28 +1508,55 @@ class ModelSerializer(Serializer): # which may map onto a model field. Any dotted field name lookups # cannot map to a field, and must be a traversal, so we're not # including those. - field_names = { - field.source for field in self._writable_fields + field_sources = OrderedDict( + (field.field_name, field.source) for field in self._writable_fields if (field.source != '*') and ('.' not in field.source) - } + ) # Special Case: Add read_only fields with defaults. - field_names |= { - field.source for field in self.fields.values() + field_sources.update(OrderedDict( + (field.field_name, field.source) for field in self.fields.values() if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source) - } + )) + + # Invert so we can find the serializer field names that correspond to + # the model field names in the unique_together sets. This also allows + # us to check that multiple fields don't map to the same source. + source_map = defaultdict(list) + for name, source in field_sources.items(): + source_map[source].append(name) # Note that we make sure to check `unique_together` both on the # base model class, but also on any parent classes. validators = [] for parent_class in model_class_inheritance_tree: for unique_together in parent_class._meta.unique_together: - if field_names.issuperset(set(unique_together)): - validator = UniqueTogetherValidator( - queryset=parent_class._default_manager, - fields=unique_together + # Skip if serializer does not map to all unique together sources + if not set(source_map).issuperset(set(unique_together)): + continue + + for source in unique_together: + assert len(source_map[source]) == 1, ( + "Unable to create `UniqueTogetherValidator` for " + "`{model}.{field}` as `{serializer}` has multiple " + "fields ({fields}) that map to this model field. " + "Either remove the extra fields, or override " + "`Meta.validators` with a `UniqueTogetherValidator` " + "using the desired field names." + .format( + model=self.Meta.model.__name__, + serializer=self.__class__.__name__, + field=source, + fields=', '.join(source_map[source]), + ) ) - validators.append(validator) + + field_names = tuple(source_map[f][0] for f in unique_together) + validator = UniqueTogetherValidator( + queryset=parent_class._default_manager, + fields=field_names + ) + validators.append(validator) return validators def get_unique_for_date_validators(self): diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 831d344dd..5b0bb4440 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,8 +1,7 @@ from django.conf.urls import include, url +from django.urls import URLResolver, path, register_converter +from django.urls.resolvers import RoutePattern -from rest_framework.compat import ( - URLResolver, get_regex_pattern, is_route_pattern, path, register_converter -) from rest_framework.settings import api_settings @@ -37,7 +36,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r for urlpattern in urlpatterns: if isinstance(urlpattern, URLResolver): # Set of included URL patterns - regex = get_regex_pattern(urlpattern) + regex = urlpattern.pattern.regex.pattern namespace = urlpattern.namespace app_name = urlpattern.app_name kwargs = urlpattern.default_kwargs @@ -48,7 +47,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r suffix_route) # if the original pattern was a RoutePattern we need to preserve it - if is_route_pattern(urlpattern): + if isinstance(urlpattern.pattern, RoutePattern): assert path is not None route = str(urlpattern.pattern) new_pattern = path(route, include((patterns, app_name), namespace), kwargs) @@ -58,7 +57,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r ret.append(new_pattern) else: # Regular URL pattern - regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern + regex = urlpattern.pattern.regex.pattern.rstrip('$').rstrip('/') + suffix_pattern view = urlpattern.callback kwargs = urlpattern.default_args name = urlpattern.name @@ -67,7 +66,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r ret.append(urlpattern) # if the original pattern was a RoutePattern we need to preserve it - if is_route_pattern(urlpattern): + if isinstance(urlpattern.pattern, RoutePattern): assert path is not None assert suffix_route is not None route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index a25880d0f..ed270be5e 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -104,7 +104,7 @@ def get_field_kwargs(field_name, model_field): if model_field.has_default() or model_field.blank or model_field.null: kwargs['required'] = False - if model_field.null and not isinstance(model_field, models.NullBooleanField): + if model_field.null: kwargs['allow_null'] = True if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))): diff --git a/setup.py b/setup.py index 99826b4d0..38e680e10 100755 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ setup( 'Framework :: Django', 'Framework :: Django :: 2.2', 'Framework :: Django :: 3.0', + 'Framework :: Django :: 3.1', 'Intended Audience :: Developers', 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', diff --git a/tests/schemas/test_coreapi.py b/tests/schemas/test_coreapi.py index a634d6968..403b3b634 100644 --- a/tests/schemas/test_coreapi.py +++ b/tests/schemas/test_coreapi.py @@ -5,11 +5,12 @@ from django.conf.urls import include, url from django.core.exceptions import PermissionDenied from django.http import Http404 from django.test import TestCase, override_settings +from django.urls import path from rest_framework import ( filters, generics, pagination, permissions, serializers ) -from rest_framework.compat import coreapi, coreschema, get_regex_pattern, path +from rest_framework.compat import coreapi, coreschema from rest_framework.decorators import action, api_view, schema from rest_framework.request import Request from rest_framework.routers import DefaultRouter, SimpleRouter @@ -1078,7 +1079,7 @@ class SchemaGenerationExclusionTests(TestCase): inspector = EndpointEnumerator(self.patterns) # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test - pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback) + pairs = [(inspector.get_path_from_regex(pattern.pattern.regex.pattern), pattern.callback) for pattern in self.patterns] should_include = [ diff --git a/tests/test_fields.py b/tests/test_fields.py index a4b78fd51..b1ad1dc66 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -13,7 +13,6 @@ from django.utils.timezone import activate, deactivate, override, utc import rest_framework from rest_framework import exceptions, serializers -from rest_framework.compat import ProhibitNullCharactersValidator from rest_framework.fields import ( BuiltinSignatureError, DjangoImageField, is_simple_callable ) @@ -747,7 +746,6 @@ class TestCharField(FieldValues): field.run_validation(' ') assert exc_info.value.detail == ['This field may not be blank.'] - @pytest.mark.skipif(ProhibitNullCharactersValidator is None, reason="Skipped on Django < 2.0") def test_null_bytes(self): field = serializers.CharField() @@ -762,8 +760,8 @@ class TestCharField(FieldValues): field = serializers.CharField() for code_point, expected_message in ( - (0xD800, 'Surrogate characters are not allowed: U+D800.'), - (0xDFFF, 'Surrogate characters are not allowed: U+DFFF.'), + (0xD800, 'Surrogate characters are not allowed: U+D800.'), + (0xDFFF, 'Surrogate characters are not allowed: U+DFFF.'), ): with pytest.raises(serializers.ValidationError) as exc_info: field.run_validation(chr(code_point)) diff --git a/tests/test_filters.py b/tests/test_filters.py index e69537666..567e5f83f 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -1,7 +1,6 @@ import datetime from importlib import reload as reload_module -import django import pytest from django.core.exceptions import ImproperlyConfigured from django.db import models @@ -191,7 +190,6 @@ class SearchFilterTests(TestCase): assert terms == ['asdf'] - @pytest.mark.skipif(django.VERSION[:2] < (2, 2), reason="requires django 2.2 or higher") def test_search_field_with_additional_transforms(self): from django.test.utils import register_lookup diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 0de628dc8..51b8f2e22 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -182,7 +182,7 @@ class TestRegularFieldMappings(TestCase): email_field = EmailField(max_length=100) float_field = FloatField() integer_field = IntegerField() - null_boolean_field = NullBooleanField(required=False) + null_boolean_field = BooleanField(allow_null=True, required=False) positive_integer_field = IntegerField() positive_small_integer_field = IntegerField() slug_field = SlugField(allow_unicode=False, max_length=100) @@ -236,6 +236,27 @@ class TestRegularFieldMappings(TestCase): self.assertEqual(repr(NullableBooleanSerializer()), expected) + def test_nullable_boolean_field_choices(self): + class NullableBooleanChoicesModel(models.Model): + CHECKLIST_OPTIONS = ( + (None, 'Unknown'), + (True, 'Yes'), + (False, 'No'), + ) + + field = models.BooleanField(null=True, choices=CHECKLIST_OPTIONS) + + class NullableBooleanChoicesSerializer(serializers.ModelSerializer): + class Meta: + model = NullableBooleanChoicesModel + fields = ['field'] + + serializer = NullableBooleanChoicesSerializer(data=dict( + field=None, + )) + self.assertTrue(serializer.is_valid()) + self.assertEqual(serializer.errors, {}) + def test_method_field(self): """ Properties and methods on the model should be allowed as `Meta.fields` diff --git a/tests/test_permissions.py b/tests/test_permissions.py index d445f271d..232c72dd2 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -3,7 +3,6 @@ import unittest from unittest import mock import django -import pytest from django.conf import settings from django.contrib.auth.models import AnonymousUser, Group, Permission, User from django.db import models @@ -14,7 +13,6 @@ from rest_framework import ( HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status, views ) -from rest_framework.compat import PY36 from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory from tests.models import BasicModel @@ -607,7 +605,6 @@ class PermissionsCompositionTests(TestCase): ) assert composed_perm().has_permission(request, None) is True - @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") def test_or_lazyness(self): request = factory.get('/1', format='json') request.user = AnonymousUser() @@ -616,19 +613,18 @@ class PermissionsCompositionTests(TestCase): with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, True) - mock_allow.assert_called_once() + assert hasperm is True + assert mock_allow.call_count == 1 mock_deny.assert_not_called() with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, True) - mock_deny.assert_called_once() - mock_allow.assert_called_once() + assert hasperm is True + assert mock_deny.call_count == 1 + assert mock_allow.call_count == 1 - @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") def test_object_or_lazyness(self): request = factory.get('/1', format='json') request.user = AnonymousUser() @@ -637,19 +633,18 @@ class PermissionsCompositionTests(TestCase): with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, True) - mock_allow.assert_called_once() + assert hasperm is True + assert mock_allow.call_count == 1 mock_deny.assert_not_called() with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, True) - mock_deny.assert_called_once() - mock_allow.assert_called_once() + assert hasperm is True + assert mock_deny.call_count == 1 + assert mock_allow.call_count == 1 - @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") def test_and_lazyness(self): request = factory.get('/1', format='json') request.user = AnonymousUser() @@ -658,19 +653,18 @@ class PermissionsCompositionTests(TestCase): with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, False) - mock_allow.assert_called_once() - mock_deny.assert_called_once() + assert hasperm is False + assert mock_allow.call_count == 1 + assert mock_deny.call_count == 1 with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) hasperm = composed_perm().has_permission(request, None) - self.assertIs(hasperm, False) + assert hasperm is False + assert mock_deny.call_count == 1 mock_allow.assert_not_called() - mock_deny.assert_called_once() - @pytest.mark.skipif(not PY36, reason="assert_called_once() not available") def test_object_and_lazyness(self): request = factory.get('/1', format='json') request.user = AnonymousUser() @@ -679,14 +673,14 @@ class PermissionsCompositionTests(TestCase): with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, False) - mock_allow.assert_called_once() - mock_deny.assert_called_once() + assert hasperm is False + assert mock_allow.call_count == 1 + assert mock_deny.call_count == 1 with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) hasperm = composed_perm().has_object_permission(request, None, None) - self.assertIs(hasperm, False) + assert hasperm is False + assert mock_deny.call_count == 1 mock_allow.assert_not_called() - mock_deny.assert_called_once() diff --git a/tests/test_routers.py b/tests/test_routers.py index ff927ff33..007cb4768 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -8,7 +8,6 @@ from django.test import TestCase, override_settings from django.urls import resolve, reverse from rest_framework import permissions, serializers, viewsets -from rest_framework.compat import get_regex_pattern from rest_framework.decorators import action from rest_framework.response import Response from rest_framework.routers import DefaultRouter, SimpleRouter @@ -192,8 +191,7 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase): def test_custom_lookup_field_route(self): detail_route = notes_router.urls[-1] - detail_url_pattern = get_regex_pattern(detail_route) - assert '' in detail_url_pattern + assert '' in detail_route.pattern.regex.pattern def test_retrieve_lookup_field_list_view(self): response = self.client.get('/example/notes/') @@ -229,7 +227,7 @@ class TestLookupValueRegex(TestCase): def test_urls_limited_by_lookup_value_regex(self): expected = ['^notes/$', '^notes/(?P[0-9a-f]{32})/$'] for idx in range(len(expected)): - assert expected[idx] == get_regex_pattern(self.urls[idx]) + assert expected[idx] == self.urls[idx].pattern.regex.pattern @override_settings(ROOT_URLCONF='tests.test_routers') @@ -249,8 +247,7 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase): def test_custom_lookup_url_kwarg_route(self): detail_route = kwarged_notes_router.urls[-1] - detail_url_pattern = get_regex_pattern(detail_route) - assert '^notes/(?P' in detail_url_pattern + assert '^notes/(?P' in detail_route.pattern.regex.pattern def test_retrieve_lookup_url_kwarg_detail_view(self): response = self.client.get('/example2/notes/fo/') @@ -273,7 +270,7 @@ class TestTrailingSlashIncluded(TestCase): def test_urls_have_trailing_slash_by_default(self): expected = ['^notes/$', '^notes/(?P[^/.]+)/$'] for idx in range(len(expected)): - assert expected[idx] == get_regex_pattern(self.urls[idx]) + assert expected[idx] == self.urls[idx].pattern.regex.pattern class TestTrailingSlashRemoved(TestCase): @@ -288,7 +285,7 @@ class TestTrailingSlashRemoved(TestCase): def test_urls_can_have_trailing_slash_removed(self): expected = ['^notes$', '^notes/(?P[^/.]+)$'] for idx in range(len(expected)): - assert expected[idx] == get_regex_pattern(self.urls[idx]) + assert expected[idx] == self.urls[idx].pattern.regex.pattern class TestNameableRoot(TestCase): diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py index f3b0591b8..25e95df15 100644 --- a/tests/test_urlpatterns.py +++ b/tests/test_urlpatterns.py @@ -2,9 +2,9 @@ import unittest from collections import namedtuple from django.test import TestCase -from django.urls import Resolver404, include +from django.urls import Resolver404, URLResolver, path, re_path +from django.urls.resolvers import RegexPattern -from rest_framework.compat import make_url_resolver, path from rest_framework.test import APIRequestFactory from rest_framework.urlpatterns import format_suffix_patterns @@ -27,7 +27,7 @@ class FormatSuffixTests(TestCase): urlpatterns = format_suffix_patterns(urlpatterns, allowed=allowed) except Exception: self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") - resolver = make_url_resolver(r'^/', urlpatterns) + resolver = URLResolver(RegexPattern(r'^/'), urlpatterns) for test_path in test_paths: try: test_path, expected_resolved = test_path diff --git a/tests/test_validators.py b/tests/test_validators.py index 21c00073d..4962cf581 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -344,6 +344,49 @@ class TestUniquenessTogetherValidation(TestCase): ] } + def test_default_validator_with_fields_with_source(self): + class TestSerializer(serializers.ModelSerializer): + name = serializers.CharField(source='race_name') + + class Meta: + model = UniquenessTogetherModel + fields = ['name', 'position'] + + serializer = TestSerializer() + expected = dedent(""" + TestSerializer(): + name = CharField(source='race_name') + position = IntegerField() + class Meta: + validators = [] + """) + assert repr(serializer) == expected + + def test_default_validator_with_multiple_fields_with_same_source(self): + class TestSerializer(serializers.ModelSerializer): + name = serializers.CharField(source='race_name') + other_name = serializers.CharField(source='race_name') + + class Meta: + model = UniquenessTogetherModel + fields = ['name', 'other_name', 'position'] + + serializer = TestSerializer(data={ + 'name': 'foo', + 'other_name': 'foo', + 'position': 1, + }) + with pytest.raises(AssertionError) as excinfo: + serializer.is_valid() + + expected = ( + "Unable to create `UniqueTogetherValidator` for " + "`UniquenessTogetherModel.race_name` as `TestSerializer` has " + "multiple fields (name, other_name) that map to this model field. " + "Either remove the extra fields, or override `Meta.validators` " + "with a `UniqueTogetherValidator` using the desired field names.") + assert str(excinfo.value) == expected + def test_allow_explict_override(self): """ Ensure validators can be explicitly removed.. diff --git a/tox.ini b/tox.ini index e5b8b6402..190865f23 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,7 @@ envlist = {py35,py36,py37}-django22, {py36,py37,py38}-django30, + {py36,py37,py38}-django31, {py36,py37,py38}-djangomaster, base,dist,lint,docs, @@ -9,6 +10,7 @@ envlist = DJANGO = 2.2: django22 3.0: django30 + 3.1: django31 master: djangomaster [testenv] @@ -20,6 +22,7 @@ setenv = deps = django22: Django>=2.2,<3.0 django30: Django>=3.0,<3.1 + django31: Django>=3.1a1,<3.2 djangomaster: https://github.com/django/django/archive/master.tar.gz -rrequirements/requirements-testing.txt -rrequirements/requirements-optionals.txt