Merge branch 'master' into fixup-modelserializer-tests

This commit is contained in:
Carlton Gibson 2017-11-22 10:16:20 +01:00 committed by GitHub
commit bee2bd31eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 159 additions and 121 deletions

View File

@ -356,8 +356,6 @@ Corresponds to `django.db.models.fields.DurationField`
The `validated_data` for these fields will contain a `datetime.timedelta` instance.
The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`.
**Note:** This field is only available with Django versions >= 1.8.
**Signature:** `DurationField()`
---
@ -681,4 +679,4 @@ The [django-rest-framework-hstore][django-rest-framework-hstore] package provide
[django-rest-framework-gis]: https://github.com/djangonauts/django-rest-framework-gis
[django-rest-framework-hstore]: https://github.com/djangonauts/django-rest-framework-hstore
[django-hstore]: https://github.com/djangonauts/django-hstore
[python-decimal-rounding-modes]: https://docs.python.org/3/library/decimal.html#rounding-modes
[python-decimal-rounding-modes]: https://docs.python.org/3/library/decimal.html#rounding-modes

View File

@ -493,6 +493,8 @@ The names in the `fields` and `exclude` attributes will normally map to model fi
Alternatively names in the `fields` options can map to properties or methods which take no arguments that exist on the model class.
Since version 3.3.0, it is **mandatory** to provide one of the attributes `fields` or `exclude`.
## Specifying nested serialization
The default `ModelSerializer` uses primary keys for relationships, but you can also easily generate nested representations using the `depth` option:
@ -1009,7 +1011,7 @@ Takes the object instance that requires serialization, and should return a primi
Takes the unvalidated incoming data as input and should return the validated data that will be made available as `serializer.validated_data`. The return value will also be passed to the `.create()` or `.update()` methods if `.save()` is called on the serializer class.
If any of the validation fails, then the method should raise a `serializers.ValidationError(errors)`. The `errors` argument should be a dictionary mapping field names (or `settings.NON_FIELD_ERRORS_KEY`) to a list of error messages. If you don't need to alter deserialization behavior and instead want to provide object-level validation, it's recommended that you intead override the [`.validate()`](#object-level-validation) method.
If any of the validation fails, then the method should raise a `serializers.ValidationError(errors)`. The `errors` argument should be a dictionary mapping field names (or `settings.NON_FIELD_ERRORS_KEY`) to a list of error messages. If you don't need to alter deserialization behavior and instead want to provide object-level validation, it's recommended that you instead override the [`.validate()`](#object-level-validation) method.
The `data` argument passed to this method will normally be the value of `request.data`, so the datatype it provides will depend on the parser classes you have configured for your API.

View File

@ -120,10 +120,10 @@ If you're intending to use the browsable API you'll probably also want to add RE
urlpatterns = [
...
url(r'^api-auth/', include('rest_framework.urls', namespace='rest_framework'))
url(r'^api-auth/', include('rest_framework.urls'))
]
Note that the URL path can be whatever you want, but you must include `'rest_framework.urls'` with the `'rest_framework'` namespace. You may leave out the namespace in Django 1.9+, and REST framework will set it for you.
Note that the URL path can be whatever you want.
## Example

View File

@ -48,8 +48,6 @@ We'll need to add our new `snippets` app and the `rest_framework` app to `INSTAL
'snippets.apps.SnippetsConfig',
)
Please note that if you're using Django <1.9, you need to replace `snippets.apps.SnippetsConfig` with `snippets`.
Okay, we're ready to roll.
## Creating a model to work with

View File

@ -142,11 +142,10 @@ Add the following import at the top of the file:
And, at the end of the file, add a pattern to include the login and logout views for the browsable API.
urlpatterns += [
url(r'^api-auth/', include('rest_framework.urls',
namespace='rest_framework')),
url(r'^api-auth/', include('rest_framework.urls'),
]
The `r'^api-auth/'` part of pattern can actually be whatever URL you want to use. The only restriction is that the included urls must use the `'rest_framework'` namespace. In Django 1.9+, REST framework will set the namespace, so you may leave it out.
The `r'^api-auth/'` part of pattern can actually be whatever URL you want to use.
Now if you open up the browser again and refresh the page you'll see a 'Login' link in the top right of the page. If you log in as one of the users you created earlier, you'll be able to create code snippets again.

View File

@ -20,14 +20,10 @@ class AuthTokenSerializer(serializers.Serializer):
user = authenticate(request=self.context.get('request'),
username=username, password=password)
if user:
# From Django 1.10 onwards the `authenticate` call simply
# returns `None` for is_active=False users.
# (Assuming the default `ModelBackend` authentication backend.)
if not user.is_active:
msg = _('User account is disabled.')
raise serializers.ValidationError(msg, code='authorization')
else:
# The authenticate call simply returns None for is_active=False
# users. (Assuming the default ModelBackend authentication
# backend.)
if not user:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg, code='authorization')
else:

View File

@ -12,7 +12,7 @@ from django.apps import apps
from django.conf import settings
from django.core import validators
from django.core.exceptions import ImproperlyConfigured
from django.db import connection, models, transaction
from django.db import models
from django.utils import six
from django.views.generic import View
@ -250,7 +250,7 @@ else:
# pytz is required from Django 1.11. Remove when dropping Django 1.10 support.
try:
import pytz # noqa
import pytz # noqa
from pytz.exceptions import InvalidTimeError
except ImportError:
InvalidTimeError = Exception
@ -297,23 +297,6 @@ class MaxLengthValidator(CustomValidatorMessage, validators.MaxLengthValidator):
pass
def set_rollback():
if hasattr(transaction, 'set_rollback'):
if connection.settings_dict.get('ATOMIC_REQUESTS', False):
# If running in >=1.6 then mark a rollback as required,
# and allow it to be handled by Django.
if connection.in_atomic_block:
transaction.set_rollback(True)
elif transaction.is_managed():
# Otherwise handle it explicitly if in managed mode.
if transaction.is_dirty():
transaction.rollback()
transaction.leave_transaction_management()
else:
# transaction not managed
pass
def authenticate(request=None, **credentials):
from django.contrib.auth import authenticate
if django.VERSION < (1, 11):

View File

@ -666,7 +666,7 @@ class BrowsableAPIRenderer(BaseRenderer):
paginator = None
csrf_cookie_name = settings.CSRF_COOKIE_NAME
csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME', 'HTTP_X_CSRFToken') # Fallback for Django 1.8
csrf_header_name = settings.CSRF_HEADER_NAME
if csrf_header_name.startswith('HTTP_'):
csrf_header_name = csrf_header_name[5:]
csrf_header_name = csrf_header_name.replace('_', '-')

View File

@ -250,9 +250,10 @@ class Request(object):
else:
self._full_data = self._data
# copy files refs to the underlying request so that closable
# copy data & files refs to the underlying request so that closable
# objects are handled appropriately.
self._request._files = self._files
self._request._post = self.POST
self._request._files = self.FILES
def _load_stream(self):
"""

View File

@ -368,7 +368,7 @@ class AutoSchema(ViewInspector):
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
return method.lower in ["get", "put", "patch", "delete"]
return method.lower() in ["get", "put", "patch", "delete"]
def get_filter_fields(self, path, method):
if not self._allows_filters(path, method):

View File

@ -1102,6 +1102,17 @@ class ModelSerializer(Serializer):
if exclude is not None:
# If `Meta.exclude` is included, then remove those fields.
for field_name in exclude:
assert field_name not in self._declared_fields, (
"Cannot both declare the field '{field_name}' and include "
"it in the {serializer_class} 'exclude' option. Remove the "
"field or, if inherited from a parent serializer, disable "
"with `{field_name} = None`."
.format(
field_name=field_name,
serializer_class=self.__class__.__name__
)
)
assert field_name in fields, (
"The field '{field_name}' was included on serializer "
"{serializer_class} in the 'exclude' option, but does "

View File

@ -6,11 +6,10 @@ your API requires authentication:
urlpatterns = [
...
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework'))
url(r'^auth/', include('rest_framework.urls'))
]
In Django versions older than 1.9, the urls must be namespaced as 'rest_framework',
and you should make sure your authentication settings include `SessionAuthentication`.
You should make sure your authentication settings include `SessionAuthentication`.
"""
from __future__ import unicode_literals

View File

@ -105,18 +105,13 @@ def _get_reverse_relationships(opts):
"""
Returns an `OrderedDict` of field names to `RelationInfo`.
"""
# Note that we have a hack here to handle internal API differences for
# this internal API across Django 1.7 -> Django 1.8.
# See: https://code.djangoproject.com/ticket/24208
reverse_relations = OrderedDict()
all_related_objects = [r for r in opts.related_objects if not r.field.many_to_many]
for relation in all_related_objects:
accessor_name = relation.get_accessor_name()
related = getattr(relation, 'related_model', relation.model)
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
related_model=related,
related_model=relation.related_model,
to_many=relation.field.remote_field.multiple,
to_field=_get_to_field(relation.field),
has_through_model=False,
@ -127,10 +122,9 @@ def _get_reverse_relationships(opts):
all_related_many_to_many_objects = [r for r in opts.related_objects if r.field.many_to_many]
for relation in all_related_many_to_many_objects:
accessor_name = relation.get_accessor_name()
related = getattr(relation, 'related_model', relation.model)
reverse_relations[accessor_name] = RelationInfo(
model_field=None,
related_model=related,
related_model=relation.related_model,
to_many=True,
# manytomany do not have to_fields
to_field=None,

View File

@ -5,7 +5,7 @@ from __future__ import unicode_literals
from django.conf import settings
from django.core.exceptions import PermissionDenied
from django.db import models
from django.db import connection, models, transaction
from django.http import Http404
from django.http.response import HttpResponseBase
from django.utils import six
@ -16,7 +16,6 @@ from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View
from rest_framework import exceptions, status
from rest_framework.compat import set_rollback
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.schemas import AutoSchema
@ -55,6 +54,12 @@ def get_view_description(view_cls, html=False):
return description
def set_rollback():
atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)
def exception_handler(exc, context):
"""
Returns the response that should be used for any given exception.

View File

@ -120,13 +120,12 @@ class DBTransactionAPIExceptionTests(TestCase):
Transaction is rollbacked by our transaction atomic block.
"""
request = factory.post('/')
num_queries = (4 if getattr(connection.features,
'can_release_savepoints', False) else 3)
num_queries = 4 if connection.features.can_release_savepoints else 3
with self.assertNumQueries(num_queries):
# 1 - begin savepoint
# 2 - insert
# 3 - rollback savepoint
# 4 - release savepoint (django>=1.8 only)
# 4 - release savepoint
with transaction.atomic():
response = self.view(request)
assert transaction.get_rollback()

View File

@ -1,44 +0,0 @@
from django.test import TestCase
from rest_framework import compat
class CompatTests(TestCase):
def setUp(self):
self.original_django_version = compat.django.VERSION
self.original_transaction = compat.transaction
def tearDown(self):
compat.django.VERSION = self.original_django_version
compat.transaction = self.original_transaction
def test_set_rollback_for_transaction_in_managed_mode(self):
class MockTransaction(object):
called_rollback = False
called_leave_transaction_management = False
def is_managed(self):
return True
def is_dirty(self):
return True
def rollback(self):
self.called_rollback = True
def leave_transaction_management(self):
self.called_leave_transaction_management = True
dirty_mock_transaction = MockTransaction()
compat.transaction = dirty_mock_transaction
compat.set_rollback()
assert dirty_mock_transaction.called_rollback is True
assert dirty_mock_transaction.called_leave_transaction_management is True
clean_mock_transaction = MockTransaction()
clean_mock_transaction.is_dirty = lambda: False
compat.transaction = clean_mock_transaction
compat.set_rollback()
assert clean_mock_transaction.called_rollback is False
assert clean_mock_transaction.called_leave_transaction_management is True

View File

@ -5,7 +5,6 @@ import unittest
import uuid
from decimal import ROUND_DOWN, ROUND_UP, Decimal
import django
import pytest
from django.http import QueryDict
from django.test import TestCase, override_settings
@ -1197,11 +1196,6 @@ class TestDateTimeField(FieldValues):
field = serializers.DateTimeField(default_timezone=utc)
if django.VERSION[:2] <= (1, 8):
# Doesn't raise an error on earlier versions of Django
TestDateTimeField.invalid_inputs.pop('2018-08-16 22:00-24:00')
class TestCustomInputFormatDateTimeField(FieldValues):
"""
Valid and invalid values for `DateTimeField` with a custom input format.

View File

@ -1,9 +1,7 @@
from __future__ import unicode_literals
import datetime
import unittest
import django
import pytest
from django.core.exceptions import ImproperlyConfigured
from django.db import models
@ -291,7 +289,6 @@ class SearchFilterToManyTests(TestCase):
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
@unittest.skipIf(django.VERSION < (1, 9), "Django 1.8 does not support transforms")
def test_multiple_filter_conditions(self):
class SearchListView(generics.ListAPIView):
queryset = Blog.objects.all()

View File

@ -1,34 +1,76 @@
from django.conf.urls import url
from django.contrib.auth.models import User
from django.http import HttpRequest
from django.test import override_settings
from rest_framework.authentication import TokenAuthentication
from rest_framework.authtoken.models import Token
from rest_framework.request import is_form_media_type
from rest_framework.response import Response
from rest_framework.test import APITestCase
from rest_framework.views import APIView
class PostView(APIView):
def post(self, request):
return Response(data=request.data, status=200)
urlpatterns = [
url(r'^$', APIView.as_view(authentication_classes=(TokenAuthentication,))),
url(r'^auth$', APIView.as_view(authentication_classes=(TokenAuthentication,))),
url(r'^post$', PostView.as_view()),
]
class MyMiddleware(object):
class RequestUserMiddleware(object):
def __init__(self, get_response):
self.get_response = get_response
def process_response(self, request, response):
def __call__(self, request):
response = self.get_response(request)
assert hasattr(request, 'user'), '`user` is not set on request'
assert request.user.is_authenticated(), '`user` is not authenticated'
assert request.user.is_authenticated, '`user` is not authenticated'
return response
class RequestPOSTMiddleware(object):
def __init__(self, get_response):
self.get_response = get_response
def __call__(self, request):
assert isinstance(request, HttpRequest)
# Parse body with underlying Django request
request.body
# Process request with DRF view
response = self.get_response(request)
# Ensure request.POST is set as appropriate
if is_form_media_type(request.content_type):
assert request.POST == {'foo': ['bar']}
else:
assert request.POST == {}
return response
@override_settings(ROOT_URLCONF='tests.test_middleware')
class TestMiddleware(APITestCase):
@override_settings(MIDDLEWARE=('tests.test_middleware.RequestUserMiddleware',))
def test_middleware_can_access_user_when_processing_response(self):
user = User.objects.create_user('john', 'john@example.com', 'password')
key = 'abcd1234'
Token.objects.create(key=key, user=user)
with self.settings(
MIDDLEWARE_CLASSES=('tests.test_middleware.MyMiddleware',)
):
auth = 'Token ' + key
self.client.get('/', HTTP_AUTHORIZATION=auth)
self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key)
@override_settings(MIDDLEWARE=('tests.test_middleware.RequestPOSTMiddleware',))
def test_middleware_can_access_request_post_when_processing_response(self):
response = self.client.post('/post', {'foo': 'bar'})
assert response.status_code == 200
response = self.client.post('/post', {'foo': 'bar'}, format='json')
assert response.status_code == 200

View File

@ -863,6 +863,22 @@ class TestSerializerMetaClass(TestCase):
with self.assertRaisesMessage(AssertionError, msginitial):
ExampleSerializer().fields
def test_declared_fields_with_exclude_option(self):
class ExampleSerializer(serializers.ModelSerializer):
text = serializers.CharField()
class Meta:
model = MetaClassTestModel
exclude = ('text',)
expected = (
"Cannot both declare the field 'text' and include it in the "
"ExampleSerializer 'exclude' option. Remove the field or, if "
"inherited from a parent serializer, disable with `text = None`."
)
with self.assertRaisesMessage(AssertionError, expected):
ExampleSerializer().fields
class Issue2704TestCase(TestCase):
def test_queryset_all(self):

View File

@ -951,3 +951,51 @@ def test_head_and_options_methods_are_excluded():
assert inspector.should_include_endpoint(path, callback)
assert inspector.get_allowed_methods(callback) == ["GET"]
class TestAutoSchemaAllowsFilters(object):
class MockAPIView(APIView):
filter_backends = [filters.OrderingFilter]
def _test(self, method):
view = self.MockAPIView()
fields = view.schema.get_filter_fields('', method)
field_names = [f.name for f in fields]
return 'ordering' in field_names
def test_get(self):
assert self._test('get')
def test_GET(self):
assert self._test('GET')
def test_put(self):
assert self._test('put')
def test_PUT(self):
assert self._test('PUT')
def test_patch(self):
assert self._test('patch')
def test_PATCH(self):
assert self._test('PATCH')
def test_delete(self):
assert self._test('delete')
def test_DELETE(self):
assert self._test('DELETE')
def test_post(self):
assert not self._test('post')
def test_POST(self):
assert not self._test('POST')
def test_foo(self):
assert not self._test('foo')
def test_FOO(self):
assert not self._test('FOO')

View File

@ -515,7 +515,7 @@ class TestSerializerValidationWithCompiledRegexField:
assert serializer.errors == {}
class Test2505Regression:
class Test2555Regression:
def test_serializer_context(self):
class NestedSerializer(serializers.Serializer):
def __init__(self, *args, **kwargs):