Merge branch 'master' into url

This commit is contained in:
Asif Saif Uddin 2020-05-15 14:31:53 +06:00 committed by GitHub
commit 7b4139d046
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 174 additions and 202 deletions

View File

@ -9,13 +9,16 @@ matrix:
- { python: "3.6", env: DJANGO=2.2 } - { python: "3.6", env: DJANGO=2.2 }
- { python: "3.6", env: DJANGO=3.0 } - { python: "3.6", env: DJANGO=3.0 }
- { python: "3.6", env: DJANGO=3.1 }
- { python: "3.6", env: DJANGO=master } - { python: "3.6", env: DJANGO=master }
- { python: "3.7", env: DJANGO=2.2 } - { python: "3.7", env: DJANGO=2.2 }
- { python: "3.7", env: DJANGO=3.0 } - { python: "3.7", env: DJANGO=3.0 }
- { python: "3.7", env: DJANGO=3.1 }
- { python: "3.7", env: DJANGO=master } - { python: "3.7", env: DJANGO=master }
- { python: "3.8", env: DJANGO=3.0 } - { python: "3.8", env: DJANGO=3.0 }
- { python: "3.8", env: DJANGO=3.1 }
- { python: "3.8", env: DJANGO=master } - { python: "3.8", env: DJANGO=master }
- { python: "3.8", env: TOXENV=base } - { python: "3.8", env: TOXENV=base }

View File

@ -603,7 +603,7 @@ The `to_internal_value()` method is called to restore a primitive datatype into
Let's look at an example of serializing a class that represents an RGB color value: Let's look at an example of serializing a class that represents an RGB color value:
class Color(object): class Color:
""" """
A color represented in the RGB colorspace. A color represented in the RGB colorspace.
""" """

View File

@ -319,7 +319,7 @@ Often you'll want to use the existing generic views, but use some slightly custo
For example, if you need to lookup objects based on multiple fields in the URL conf, you could create a mixin class like the following: For example, if you need to lookup objects based on multiple fields in the URL conf, you could create a mixin class like the following:
class MultipleFieldLookupMixin(object): class MultipleFieldLookupMixin:
""" """
Apply this mixin to any view or viewset to get multiple field filtering Apply this mixin to any view or viewset to get multiple field filtering
based on a `lookup_fields` attribute, instead of the default single field filtering. based on a `lookup_fields` attribute, instead of the default single field filtering.

View File

@ -21,7 +21,7 @@ Let's start by creating a simple object we can use for example purposes:
from datetime import datetime from datetime import datetime
class Comment(object): class Comment:
def __init__(self, email, content, created=None): def __init__(self, email, content, created=None):
self.email = email self.email = email
self.content = content self.content = content

View File

@ -282,7 +282,7 @@ to your `Serializer` subclass. This is documented in the
To write a class-based validator, use the `__call__` method. Class-based validators are useful as they allow you to parameterize and reuse behavior. To write a class-based validator, use the `__call__` method. Class-based validators are useful as they allow you to parameterize and reuse behavior.
class MultipleOf(object): class MultipleOf:
def __init__(self, base): def __init__(self, base):
self.base = base self.base = base

View File

@ -2,75 +2,9 @@
The `compat` module provides support for backwards compatibility with older The `compat` module provides support for backwards compatibility with older
versions of Django/Python, and compatibility wrappers around optional packages. versions of Django/Python, and compatibility wrappers around optional packages.
""" """
import sys
from django.conf import settings from django.conf import settings
from django.views.generic import View from django.views.generic import View
try:
from django.urls import ( # noqa
URLPattern,
URLResolver,
)
except ImportError:
# Will be removed in Django 2.0
from django.urls import ( # noqa
RegexURLPattern as URLPattern,
RegexURLResolver as URLResolver,
)
try:
from django.core.validators import ProhibitNullCharactersValidator # noqa
except ImportError:
ProhibitNullCharactersValidator = None
def get_original_route(urlpattern):
"""
Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This
is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path().
"""
if hasattr(urlpattern, 'pattern'):
# Django 2.0
return str(urlpattern.pattern)
else:
# Django < 2.0
return urlpattern.regex.pattern
def get_regex_pattern(urlpattern):
"""
Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression,
unlike get_original_route above.
"""
if hasattr(urlpattern, 'pattern'):
# Django 2.0
return urlpattern.pattern.regex.pattern
else:
# Django < 2.0
return urlpattern.regex.pattern
def is_route_pattern(urlpattern):
if hasattr(urlpattern, 'pattern'):
# Django 2.0
from django.urls.resolvers import RoutePattern
return isinstance(urlpattern.pattern, RoutePattern)
else:
# Django < 2.0
return False
def make_url_resolver(regex, urlpatterns):
try:
# Django 2.0
from django.urls.resolvers import RegexPattern
return URLResolver(RegexPattern(regex), urlpatterns)
except ImportError:
# Django < 2.0
return URLResolver(regex, urlpatterns)
def unicode_http_header(value): def unicode_http_header(value):
# Coerce HTTP header value to unicode. # Coerce HTTP header value to unicode.
@ -217,22 +151,8 @@ else:
return False return False
# Django 1.x url routing syntax. Remove when dropping Django 1.11 support.
try:
from django.urls import include, path, re_path, register_converter # noqa
except ImportError:
from django.conf.urls import include, url # noqa
path = None
register_converter = None
re_path = url
# `separators` argument to `json.dumps()` differs between 2.x and 3.x # `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767 # See: https://bugs.python.org/issue22767
SHORT_SEPARATORS = (',', ':') SHORT_SEPARATORS = (',', ':')
LONG_SEPARATORS = (', ', ': ') LONG_SEPARATORS = (', ', ': ')
INDENT_SEPARATORS = (',', ': ') INDENT_SEPARATORS = (',', ': ')
# Version Constants.
PY36 = sys.version_info >= (3, 6)

View File

@ -14,7 +14,8 @@ from django.core.exceptions import ObjectDoesNotExist
from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.validators import ( from django.core.validators import (
EmailValidator, MaxLengthValidator, MaxValueValidator, MinLengthValidator, EmailValidator, MaxLengthValidator, MaxValueValidator, MinLengthValidator,
MinValueValidator, RegexValidator, URLValidator, ip_address_validators MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
URLValidator, ip_address_validators
) )
from django.forms import FilePathField as DjangoFilePathField from django.forms import FilePathField as DjangoFilePathField
from django.forms import ImageField as DjangoImageField from django.forms import ImageField as DjangoImageField
@ -30,8 +31,9 @@ from django.utils.timezone import utc
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from pytz.exceptions import InvalidTimeError from pytz.exceptions import InvalidTimeError
from rest_framework import ISO_8601, RemovedInDRF313Warning from rest_framework import (
from rest_framework.compat import ProhibitNullCharactersValidator ISO_8601, RemovedInDRF313Warning, RemovedInDRF314Warning
)
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, humanize_datetime, json, representation from rest_framework.utils import html, humanize_datetime, json, representation
@ -740,55 +742,22 @@ class BooleanField(Field):
return bool(value) return bool(value)
class NullBooleanField(Field): class NullBooleanField(BooleanField):
default_error_messages = {
'invalid': _('Must be a valid boolean.')
}
initial = None initial = None
TRUE_VALUES = {
't', 'T',
'y', 'Y', 'yes', 'YES',
'true', 'True', 'TRUE',
'on', 'On', 'ON',
'1', 1,
True
}
FALSE_VALUES = {
'f', 'F',
'n', 'N', 'no', 'NO',
'false', 'False', 'FALSE',
'off', 'Off', 'OFF',
'0', 0, 0.0,
False
}
NULL_VALUES = {'null', 'Null', 'NULL', '', None}
def __init__(self, **kwargs): def __init__(self, **kwargs):
warnings.warn(
"The `NullBooleanField` is deprecated and will be removed starting "
"with 3.14. Instead use the `BooleanField` field and set "
"`null=True` which does the same thing.",
RemovedInDRF314Warning, stacklevel=2
)
assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.' assert 'allow_null' not in kwargs, '`allow_null` is not a valid option.'
kwargs['allow_null'] = True kwargs['allow_null'] = True
super().__init__(**kwargs) super().__init__(**kwargs)
def to_internal_value(self, data):
try:
if data in self.TRUE_VALUES:
return True
elif data in self.FALSE_VALUES:
return False
elif data in self.NULL_VALUES:
return None
except TypeError: # Input is an unhashable type
pass
self.fail('invalid', input=data)
def to_representation(self, value):
if value in self.NULL_VALUES:
return None
if value in self.TRUE_VALUES:
return True
elif value in self.FALSE_VALUES:
return False
return bool(value)
# String types... # String types...
@ -816,9 +785,7 @@ class CharField(Field):
self.validators.append( self.validators.append(
MinLengthValidator(self.min_length, message=message)) MinLengthValidator(self.min_length, message=message))
# ProhibitNullCharactersValidator is None on Django < 2.0 self.validators.append(ProhibitNullCharactersValidator())
if ProhibitNullCharactersValidator is not None:
self.validators.append(ProhibitNullCharactersValidator())
self.validators.append(ProhibitSurrogateCharactersValidator()) self.validators.append(ProhibitSurrogateCharactersValidator())
def run_validation(self, data=empty): def run_validation(self, data=empty):

View File

@ -10,9 +10,9 @@ from django.conf import settings
from django.contrib.admindocs.views import simplify_regex from django.contrib.admindocs.views import simplify_regex
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from django.urls import URLPattern, URLResolver
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import URLPattern, URLResolver, get_original_route
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import _get_pk from rest_framework.utils.model_meta import _get_pk
@ -79,7 +79,7 @@ class EndpointEnumerator:
api_endpoints = [] api_endpoints = []
for pattern in patterns: for pattern in patterns:
path_regex = prefix + get_original_route(pattern) path_regex = prefix + str(pattern.pattern)
if isinstance(pattern, URLPattern): if isinstance(pattern, URLPattern):
path = self.get_path_from_regex(path_regex) path = self.get_path_from_regex(path_regex)
callback = pattern.callback callback = pattern.callback
@ -143,7 +143,7 @@ class EndpointEnumerator:
return [method for method in methods if method not in ('OPTIONS', 'HEAD')] return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
class BaseSchemaGenerator(object): class BaseSchemaGenerator:
endpoint_inspector_cls = EndpointEnumerator endpoint_inspector_cls = EndpointEnumerator
# 'pk' isn't great as an externally exposed name for an identifier, # 'pk' isn't great as an externally exposed name for an identifier,

View File

@ -13,7 +13,7 @@ response content is handled by parsers and renderers.
import copy import copy
import inspect import inspect
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict, defaultdict
from collections.abc import Mapping from collections.abc import Mapping
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
@ -868,7 +868,7 @@ class ModelSerializer(Serializer):
models.FloatField: FloatField, models.FloatField: FloatField,
models.ImageField: ImageField, models.ImageField: ImageField,
models.IntegerField: IntegerField, models.IntegerField: IntegerField,
models.NullBooleanField: NullBooleanField, models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField, models.PositiveIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField, models.PositiveSmallIntegerField: IntegerField,
models.SlugField: SlugField, models.SlugField: SlugField,
@ -1508,28 +1508,55 @@ class ModelSerializer(Serializer):
# which may map onto a model field. Any dotted field name lookups # which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not # cannot map to a field, and must be a traversal, so we're not
# including those. # including those.
field_names = { field_sources = OrderedDict(
field.source for field in self._writable_fields (field.field_name, field.source) for field in self._writable_fields
if (field.source != '*') and ('.' not in field.source) if (field.source != '*') and ('.' not in field.source)
} )
# Special Case: Add read_only fields with defaults. # Special Case: Add read_only fields with defaults.
field_names |= { field_sources.update(OrderedDict(
field.source for field in self.fields.values() (field.field_name, field.source) for field in self.fields.values()
if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source) if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
} ))
# Invert so we can find the serializer field names that correspond to
# the model field names in the unique_together sets. This also allows
# us to check that multiple fields don't map to the same source.
source_map = defaultdict(list)
for name, source in field_sources.items():
source_map[source].append(name)
# Note that we make sure to check `unique_together` both on the # Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes. # base model class, but also on any parent classes.
validators = [] validators = []
for parent_class in model_class_inheritance_tree: for parent_class in model_class_inheritance_tree:
for unique_together in parent_class._meta.unique_together: for unique_together in parent_class._meta.unique_together:
if field_names.issuperset(set(unique_together)): # Skip if serializer does not map to all unique together sources
validator = UniqueTogetherValidator( if not set(source_map).issuperset(set(unique_together)):
queryset=parent_class._default_manager, continue
fields=unique_together
for source in unique_together:
assert len(source_map[source]) == 1, (
"Unable to create `UniqueTogetherValidator` for "
"`{model}.{field}` as `{serializer}` has multiple "
"fields ({fields}) that map to this model field. "
"Either remove the extra fields, or override "
"`Meta.validators` with a `UniqueTogetherValidator` "
"using the desired field names."
.format(
model=self.Meta.model.__name__,
serializer=self.__class__.__name__,
field=source,
fields=', '.join(source_map[source]),
)
) )
validators.append(validator)
field_names = tuple(source_map[f][0] for f in unique_together)
validator = UniqueTogetherValidator(
queryset=parent_class._default_manager,
fields=field_names
)
validators.append(validator)
return validators return validators
def get_unique_for_date_validators(self): def get_unique_for_date_validators(self):

View File

@ -1,8 +1,7 @@
from django.conf.urls import include, url from django.conf.urls import include, url
from django.urls import URLResolver, path, register_converter
from django.urls.resolvers import RoutePattern
from rest_framework.compat import (
URLResolver, get_regex_pattern, is_route_pattern, path, register_converter
)
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -37,7 +36,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
for urlpattern in urlpatterns: for urlpattern in urlpatterns:
if isinstance(urlpattern, URLResolver): if isinstance(urlpattern, URLResolver):
# Set of included URL patterns # Set of included URL patterns
regex = get_regex_pattern(urlpattern) regex = urlpattern.pattern.regex.pattern
namespace = urlpattern.namespace namespace = urlpattern.namespace
app_name = urlpattern.app_name app_name = urlpattern.app_name
kwargs = urlpattern.default_kwargs kwargs = urlpattern.default_kwargs
@ -48,7 +47,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
suffix_route) suffix_route)
# if the original pattern was a RoutePattern we need to preserve it # if the original pattern was a RoutePattern we need to preserve it
if is_route_pattern(urlpattern): if isinstance(urlpattern.pattern, RoutePattern):
assert path is not None assert path is not None
route = str(urlpattern.pattern) route = str(urlpattern.pattern)
new_pattern = path(route, include((patterns, app_name), namespace), kwargs) new_pattern = path(route, include((patterns, app_name), namespace), kwargs)
@ -58,7 +57,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
ret.append(new_pattern) ret.append(new_pattern)
else: else:
# Regular URL pattern # Regular URL pattern
regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern regex = urlpattern.pattern.regex.pattern.rstrip('$').rstrip('/') + suffix_pattern
view = urlpattern.callback view = urlpattern.callback
kwargs = urlpattern.default_args kwargs = urlpattern.default_args
name = urlpattern.name name = urlpattern.name
@ -67,7 +66,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
ret.append(urlpattern) ret.append(urlpattern)
# if the original pattern was a RoutePattern we need to preserve it # if the original pattern was a RoutePattern we need to preserve it
if is_route_pattern(urlpattern): if isinstance(urlpattern.pattern, RoutePattern):
assert path is not None assert path is not None
assert suffix_route is not None assert suffix_route is not None
route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route

View File

@ -104,7 +104,7 @@ def get_field_kwargs(field_name, model_field):
if model_field.has_default() or model_field.blank or model_field.null: if model_field.has_default() or model_field.blank or model_field.null:
kwargs['required'] = False kwargs['required'] = False
if model_field.null and not isinstance(model_field, models.NullBooleanField): if model_field.null:
kwargs['allow_null'] = True kwargs['allow_null'] = True
if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))): if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))):

View File

@ -91,6 +91,7 @@ setup(
'Framework :: Django', 'Framework :: Django',
'Framework :: Django :: 2.2', 'Framework :: Django :: 2.2',
'Framework :: Django :: 3.0', 'Framework :: Django :: 3.0',
'Framework :: Django :: 3.1',
'Intended Audience :: Developers', 'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License', 'License :: OSI Approved :: BSD License',
'Operating System :: OS Independent', 'Operating System :: OS Independent',

View File

@ -5,11 +5,12 @@ from django.conf.urls import include, url
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.urls import path
from rest_framework import ( from rest_framework import (
filters, generics, pagination, permissions, serializers filters, generics, pagination, permissions, serializers
) )
from rest_framework.compat import coreapi, coreschema, get_regex_pattern, path from rest_framework.compat import coreapi, coreschema
from rest_framework.decorators import action, api_view, schema from rest_framework.decorators import action, api_view, schema
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework.routers import DefaultRouter, SimpleRouter
@ -1078,7 +1079,7 @@ class SchemaGenerationExclusionTests(TestCase):
inspector = EndpointEnumerator(self.patterns) inspector = EndpointEnumerator(self.patterns)
# Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test # Not pretty. Mimics internals of EndpointEnumerator to put should_include_endpoint under test
pairs = [(inspector.get_path_from_regex(get_regex_pattern(pattern)), pattern.callback) pairs = [(inspector.get_path_from_regex(pattern.pattern.regex.pattern), pattern.callback)
for pattern in self.patterns] for pattern in self.patterns]
should_include = [ should_include = [

View File

@ -13,7 +13,6 @@ from django.utils.timezone import activate, deactivate, override, utc
import rest_framework import rest_framework
from rest_framework import exceptions, serializers from rest_framework import exceptions, serializers
from rest_framework.compat import ProhibitNullCharactersValidator
from rest_framework.fields import ( from rest_framework.fields import (
BuiltinSignatureError, DjangoImageField, is_simple_callable BuiltinSignatureError, DjangoImageField, is_simple_callable
) )
@ -747,7 +746,6 @@ class TestCharField(FieldValues):
field.run_validation(' ') field.run_validation(' ')
assert exc_info.value.detail == ['This field may not be blank.'] assert exc_info.value.detail == ['This field may not be blank.']
@pytest.mark.skipif(ProhibitNullCharactersValidator is None, reason="Skipped on Django < 2.0")
def test_null_bytes(self): def test_null_bytes(self):
field = serializers.CharField() field = serializers.CharField()
@ -762,8 +760,8 @@ class TestCharField(FieldValues):
field = serializers.CharField() field = serializers.CharField()
for code_point, expected_message in ( for code_point, expected_message in (
(0xD800, 'Surrogate characters are not allowed: U+D800.'), (0xD800, 'Surrogate characters are not allowed: U+D800.'),
(0xDFFF, 'Surrogate characters are not allowed: U+DFFF.'), (0xDFFF, 'Surrogate characters are not allowed: U+DFFF.'),
): ):
with pytest.raises(serializers.ValidationError) as exc_info: with pytest.raises(serializers.ValidationError) as exc_info:
field.run_validation(chr(code_point)) field.run_validation(chr(code_point))

View File

@ -1,7 +1,6 @@
import datetime import datetime
from importlib import reload as reload_module from importlib import reload as reload_module
import django
import pytest import pytest
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import models from django.db import models
@ -191,7 +190,6 @@ class SearchFilterTests(TestCase):
assert terms == ['asdf'] assert terms == ['asdf']
@pytest.mark.skipif(django.VERSION[:2] < (2, 2), reason="requires django 2.2 or higher")
def test_search_field_with_additional_transforms(self): def test_search_field_with_additional_transforms(self):
from django.test.utils import register_lookup from django.test.utils import register_lookup

View File

@ -182,7 +182,7 @@ class TestRegularFieldMappings(TestCase):
email_field = EmailField(max_length=100) email_field = EmailField(max_length=100)
float_field = FloatField() float_field = FloatField()
integer_field = IntegerField() integer_field = IntegerField()
null_boolean_field = NullBooleanField(required=False) null_boolean_field = BooleanField(allow_null=True, required=False)
positive_integer_field = IntegerField() positive_integer_field = IntegerField()
positive_small_integer_field = IntegerField() positive_small_integer_field = IntegerField()
slug_field = SlugField(allow_unicode=False, max_length=100) slug_field = SlugField(allow_unicode=False, max_length=100)
@ -236,6 +236,27 @@ class TestRegularFieldMappings(TestCase):
self.assertEqual(repr(NullableBooleanSerializer()), expected) self.assertEqual(repr(NullableBooleanSerializer()), expected)
def test_nullable_boolean_field_choices(self):
class NullableBooleanChoicesModel(models.Model):
CHECKLIST_OPTIONS = (
(None, 'Unknown'),
(True, 'Yes'),
(False, 'No'),
)
field = models.BooleanField(null=True, choices=CHECKLIST_OPTIONS)
class NullableBooleanChoicesSerializer(serializers.ModelSerializer):
class Meta:
model = NullableBooleanChoicesModel
fields = ['field']
serializer = NullableBooleanChoicesSerializer(data=dict(
field=None,
))
self.assertTrue(serializer.is_valid())
self.assertEqual(serializer.errors, {})
def test_method_field(self): def test_method_field(self):
""" """
Properties and methods on the model should be allowed as `Meta.fields` Properties and methods on the model should be allowed as `Meta.fields`

View File

@ -3,7 +3,6 @@ import unittest
from unittest import mock from unittest import mock
import django import django
import pytest
from django.conf import settings from django.conf import settings
from django.contrib.auth.models import AnonymousUser, Group, Permission, User from django.contrib.auth.models import AnonymousUser, Group, Permission, User
from django.db import models from django.db import models
@ -14,7 +13,6 @@ from rest_framework import (
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
status, views status, views
) )
from rest_framework.compat import PY36
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from tests.models import BasicModel from tests.models import BasicModel
@ -607,7 +605,6 @@ class PermissionsCompositionTests(TestCase):
) )
assert composed_perm().has_permission(request, None) is True assert composed_perm().has_permission(request, None) is True
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_or_lazyness(self): def test_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
@ -616,19 +613,18 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() assert mock_allow.call_count == 1
mock_deny.assert_not_called() mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_deny.assert_called_once() assert mock_deny.call_count == 1
mock_allow.assert_called_once() assert mock_allow.call_count == 1
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_or_lazyness(self): def test_object_or_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
@ -637,19 +633,18 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny | permissions.IsAuthenticated) composed_perm = (permissions.AllowAny | permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_allow.assert_called_once() assert mock_allow.call_count == 1
mock_deny.assert_not_called() mock_deny.assert_not_called()
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated | permissions.AllowAny) composed_perm = (permissions.IsAuthenticated | permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, True) assert hasperm is True
mock_deny.assert_called_once() assert mock_deny.call_count == 1
mock_allow.assert_called_once() assert mock_allow.call_count == 1
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_and_lazyness(self): def test_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
@ -658,19 +653,18 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() assert mock_allow.call_count == 1
mock_deny.assert_called_once() assert mock_deny.call_count == 1
with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_permission(request, None) hasperm = composed_perm().has_permission(request, None)
self.assertIs(hasperm, False) assert hasperm is False
assert mock_deny.call_count == 1
mock_allow.assert_not_called() mock_allow.assert_not_called()
mock_deny.assert_called_once()
@pytest.mark.skipif(not PY36, reason="assert_called_once() not available")
def test_object_and_lazyness(self): def test_object_and_lazyness(self):
request = factory.get('/1', format='json') request = factory.get('/1', format='json')
request.user = AnonymousUser() request.user = AnonymousUser()
@ -679,14 +673,14 @@ class PermissionsCompositionTests(TestCase):
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.AllowAny & permissions.IsAuthenticated) composed_perm = (permissions.AllowAny & permissions.IsAuthenticated)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False) assert hasperm is False
mock_allow.assert_called_once() assert mock_allow.call_count == 1
mock_deny.assert_called_once() assert mock_deny.call_count == 1
with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow: with mock.patch.object(permissions.AllowAny, 'has_object_permission', return_value=True) as mock_allow:
with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny: with mock.patch.object(permissions.IsAuthenticated, 'has_object_permission', return_value=False) as mock_deny:
composed_perm = (permissions.IsAuthenticated & permissions.AllowAny) composed_perm = (permissions.IsAuthenticated & permissions.AllowAny)
hasperm = composed_perm().has_object_permission(request, None, None) hasperm = composed_perm().has_object_permission(request, None, None)
self.assertIs(hasperm, False) assert hasperm is False
assert mock_deny.call_count == 1
mock_allow.assert_not_called() mock_allow.assert_not_called()
mock_deny.assert_called_once()

View File

@ -8,7 +8,6 @@ from django.test import TestCase, override_settings
from django.urls import resolve, reverse from django.urls import resolve, reverse
from rest_framework import permissions, serializers, viewsets from rest_framework import permissions, serializers, viewsets
from rest_framework.compat import get_regex_pattern
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework.routers import DefaultRouter, SimpleRouter
@ -192,8 +191,7 @@ class TestCustomLookupFields(URLPatternsTestCase, TestCase):
def test_custom_lookup_field_route(self): def test_custom_lookup_field_route(self):
detail_route = notes_router.urls[-1] detail_route = notes_router.urls[-1]
detail_url_pattern = get_regex_pattern(detail_route) assert '<uuid>' in detail_route.pattern.regex.pattern
assert '<uuid>' in detail_url_pattern
def test_retrieve_lookup_field_list_view(self): def test_retrieve_lookup_field_list_view(self):
response = self.client.get('/example/notes/') response = self.client.get('/example/notes/')
@ -229,7 +227,7 @@ class TestLookupValueRegex(TestCase):
def test_urls_limited_by_lookup_value_regex(self): def test_urls_limited_by_lookup_value_regex(self):
expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$'] expected = ['^notes/$', '^notes/(?P<uuid>[0-9a-f]{32})/$']
for idx in range(len(expected)): for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == self.urls[idx].pattern.regex.pattern
@override_settings(ROOT_URLCONF='tests.test_routers') @override_settings(ROOT_URLCONF='tests.test_routers')
@ -249,8 +247,7 @@ class TestLookupUrlKwargs(URLPatternsTestCase, TestCase):
def test_custom_lookup_url_kwarg_route(self): def test_custom_lookup_url_kwarg_route(self):
detail_route = kwarged_notes_router.urls[-1] detail_route = kwarged_notes_router.urls[-1]
detail_url_pattern = get_regex_pattern(detail_route) assert '^notes/(?P<text>' in detail_route.pattern.regex.pattern
assert '^notes/(?P<text>' in detail_url_pattern
def test_retrieve_lookup_url_kwarg_detail_view(self): def test_retrieve_lookup_url_kwarg_detail_view(self):
response = self.client.get('/example2/notes/fo/') response = self.client.get('/example2/notes/fo/')
@ -273,7 +270,7 @@ class TestTrailingSlashIncluded(TestCase):
def test_urls_have_trailing_slash_by_default(self): def test_urls_have_trailing_slash_by_default(self):
expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$'] expected = ['^notes/$', '^notes/(?P<pk>[^/.]+)/$']
for idx in range(len(expected)): for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == self.urls[idx].pattern.regex.pattern
class TestTrailingSlashRemoved(TestCase): class TestTrailingSlashRemoved(TestCase):
@ -288,7 +285,7 @@ class TestTrailingSlashRemoved(TestCase):
def test_urls_can_have_trailing_slash_removed(self): def test_urls_can_have_trailing_slash_removed(self):
expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$'] expected = ['^notes$', '^notes/(?P<pk>[^/.]+)$']
for idx in range(len(expected)): for idx in range(len(expected)):
assert expected[idx] == get_regex_pattern(self.urls[idx]) assert expected[idx] == self.urls[idx].pattern.regex.pattern
class TestNameableRoot(TestCase): class TestNameableRoot(TestCase):

View File

@ -2,9 +2,9 @@ import unittest
from collections import namedtuple from collections import namedtuple
from django.test import TestCase from django.test import TestCase
from django.urls import Resolver404, include from django.urls import Resolver404, URLResolver, path, re_path
from django.urls.resolvers import RegexPattern
from rest_framework.compat import make_url_resolver, path
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.urlpatterns import format_suffix_patterns from rest_framework.urlpatterns import format_suffix_patterns
@ -27,7 +27,7 @@ class FormatSuffixTests(TestCase):
urlpatterns = format_suffix_patterns(urlpatterns, allowed=allowed) urlpatterns = format_suffix_patterns(urlpatterns, allowed=allowed)
except Exception: except Exception:
self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns")
resolver = make_url_resolver(r'^/', urlpatterns) resolver = URLResolver(RegexPattern(r'^/'), urlpatterns)
for test_path in test_paths: for test_path in test_paths:
try: try:
test_path, expected_resolved = test_path test_path, expected_resolved = test_path

View File

@ -344,6 +344,49 @@ class TestUniquenessTogetherValidation(TestCase):
] ]
} }
def test_default_validator_with_fields_with_source(self):
class TestSerializer(serializers.ModelSerializer):
name = serializers.CharField(source='race_name')
class Meta:
model = UniquenessTogetherModel
fields = ['name', 'position']
serializer = TestSerializer()
expected = dedent("""
TestSerializer():
name = CharField(source='race_name')
position = IntegerField()
class Meta:
validators = [<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('name', 'position'))>]
""")
assert repr(serializer) == expected
def test_default_validator_with_multiple_fields_with_same_source(self):
class TestSerializer(serializers.ModelSerializer):
name = serializers.CharField(source='race_name')
other_name = serializers.CharField(source='race_name')
class Meta:
model = UniquenessTogetherModel
fields = ['name', 'other_name', 'position']
serializer = TestSerializer(data={
'name': 'foo',
'other_name': 'foo',
'position': 1,
})
with pytest.raises(AssertionError) as excinfo:
serializer.is_valid()
expected = (
"Unable to create `UniqueTogetherValidator` for "
"`UniquenessTogetherModel.race_name` as `TestSerializer` has "
"multiple fields (name, other_name) that map to this model field. "
"Either remove the extra fields, or override `Meta.validators` "
"with a `UniqueTogetherValidator` using the desired field names.")
assert str(excinfo.value) == expected
def test_allow_explict_override(self): def test_allow_explict_override(self):
""" """
Ensure validators can be explicitly removed.. Ensure validators can be explicitly removed..

View File

@ -2,6 +2,7 @@
envlist = envlist =
{py35,py36,py37}-django22, {py35,py36,py37}-django22,
{py36,py37,py38}-django30, {py36,py37,py38}-django30,
{py36,py37,py38}-django31,
{py36,py37,py38}-djangomaster, {py36,py37,py38}-djangomaster,
base,dist,lint,docs, base,dist,lint,docs,
@ -9,6 +10,7 @@ envlist =
DJANGO = DJANGO =
2.2: django22 2.2: django22
3.0: django30 3.0: django30
3.1: django31
master: djangomaster master: djangomaster
[testenv] [testenv]
@ -20,6 +22,7 @@ setenv =
deps = deps =
django22: Django>=2.2,<3.0 django22: Django>=2.2,<3.0
django30: Django>=3.0,<3.1 django30: Django>=3.0,<3.1
django31: Django>=3.1a1,<3.2
djangomaster: https://github.com/django/django/archive/master.tar.gz djangomaster: https://github.com/django/django/archive/master.tar.gz
-rrequirements/requirements-testing.txt -rrequirements/requirements-testing.txt
-rrequirements/requirements-optionals.txt -rrequirements/requirements-optionals.txt