mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 03:23:59 +03:00
Merge remote-tracking branch 'origin/master' into documentation/include_translations_in_process
This commit is contained in:
commit
c82f6aa835
|
@ -247,6 +247,10 @@ Unauthenticated responses that are denied permission will result in an `HTTP 403
|
|||
|
||||
If you're using an AJAX style API with SessionAuthentication, you'll need to make sure you include a valid CSRF token for any "unsafe" HTTP method calls, such as `PUT`, `PATCH`, `POST` or `DELETE` requests. See the [Django CSRF documentation][csrf-ajax] for more details.
|
||||
|
||||
**Warning**: Always use Django's standard login view when creating login pages. This will ensure your login views are properly protected.
|
||||
|
||||
CSRF validation in REST framework works slightly differently to standard Django due to the need to support both session and non-session based authentication to the same views. This means that only authenticated requests require CSRF tokens, and anonymous requests may be sent without CSRF tokens. This behaviour is not suitable for login views, which should always have CSRF validation applied.
|
||||
|
||||
# Custom authentication
|
||||
|
||||
To implement a custom authentication scheme, subclass `BaseAuthentication` and override the `.authenticate(self, request)` method. The method should return a two-tuple of `(user, auth)` if authentication succeeds, or `None` otherwise.
|
||||
|
|
|
@ -302,6 +302,18 @@ Corresponds to `django.db.models.fields.TimeField`
|
|||
|
||||
Format strings may either be [Python strftime formats][strftime] which explicitly specify the format, or the special string `'iso-8601'`, which indicates that [ISO 8601][iso8601] style times should be used. (eg `'12:34:56.000000'`)
|
||||
|
||||
## DurationField
|
||||
|
||||
A Duration representation.
|
||||
Corresponds to `django.db.models.fields.DurationField`
|
||||
|
||||
The `validated_data` for these fields will contain a `datetime.timedelta` instance.
|
||||
The representation is a string following this format `'[DD] [HH:[MM:]]ss[.uuuuuu]'`.
|
||||
|
||||
**Note:** This field is only available with Django versions >= 1.8.
|
||||
|
||||
**Signature:** `DurationField()`
|
||||
|
||||
---
|
||||
|
||||
# Choice selection fields
|
||||
|
|
|
@ -7,6 +7,7 @@ versions of django/python, and compatibility wrappers around optional packages.
|
|||
from __future__ import unicode_literals
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.conf import settings
|
||||
from django.db import connection, transaction
|
||||
from django.utils.encoding import force_text
|
||||
from django.utils.six.moves.urllib.parse import urlparse as _urlparse
|
||||
from django.utils import six
|
||||
|
@ -258,3 +259,27 @@ else:
|
|||
SHORT_SEPARATORS = (b',', b':')
|
||||
LONG_SEPARATORS = (b', ', b': ')
|
||||
INDENT_SEPARATORS = (b',', b': ')
|
||||
|
||||
|
||||
if django.VERSION >= (1, 8):
|
||||
from django.db.models import DurationField
|
||||
from django.utils.dateparse import parse_duration
|
||||
from django.utils.duration import duration_string
|
||||
else:
|
||||
DurationField = duration_string = parse_duration = None
|
||||
|
||||
|
||||
def set_rollback():
|
||||
if hasattr(transaction, 'set_rollback'):
|
||||
if connection.settings_dict.get('ATOMIC_REQUESTS', False):
|
||||
# If running in >=1.6 then mark a rollback as required,
|
||||
# and allow it to be handled by Django.
|
||||
transaction.set_rollback(True)
|
||||
elif transaction.is_managed():
|
||||
# Otherwise handle it explicitly if in managed mode.
|
||||
if transaction.is_dirty():
|
||||
transaction.rollback()
|
||||
transaction.leave_transaction_management()
|
||||
else:
|
||||
# transaction not managed
|
||||
pass
|
||||
|
|
|
@ -12,7 +12,7 @@ from rest_framework import ISO_8601
|
|||
from rest_framework.compat import (
|
||||
EmailValidator, MinValueValidator, MaxValueValidator,
|
||||
MinLengthValidator, MaxLengthValidator, URLValidator, OrderedDict,
|
||||
unicode_repr, unicode_to_repr
|
||||
unicode_repr, unicode_to_repr, parse_duration, duration_string,
|
||||
)
|
||||
from rest_framework.exceptions import ValidationError
|
||||
from rest_framework.settings import api_settings
|
||||
|
@ -1003,6 +1003,29 @@ class TimeField(Field):
|
|||
return value.strftime(self.format)
|
||||
|
||||
|
||||
class DurationField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _('Duration has wrong format. Use one of these formats instead: {format}.'),
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if parse_duration is None:
|
||||
raise NotImplementedError(
|
||||
'DurationField not supported for django versions prior to 1.8')
|
||||
return super(DurationField, self).__init__(*args, **kwargs)
|
||||
|
||||
def to_internal_value(self, value):
|
||||
if isinstance(value, datetime.timedelta):
|
||||
return value
|
||||
parsed = parse_duration(value)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
self.fail('invalid', format='[DD] [HH:[MM:]]ss[.uuuuuu]')
|
||||
|
||||
def to_representation(self, value):
|
||||
return duration_string(value)
|
||||
|
||||
|
||||
# Choice types...
|
||||
|
||||
class ChoiceField(Field):
|
||||
|
@ -1060,7 +1083,11 @@ class MultipleChoiceField(ChoiceField):
|
|||
# We override the default field access in order to support
|
||||
# lists in HTML forms.
|
||||
if html.is_html_input(dictionary):
|
||||
return dictionary.getlist(self.field_name)
|
||||
ret = dictionary.getlist(self.field_name)
|
||||
if getattr(self.root, 'partial', False) and not ret:
|
||||
ret = empty
|
||||
return ret
|
||||
|
||||
return dictionary.get(self.field_name, empty)
|
||||
|
||||
def to_internal_value(self, data):
|
||||
|
|
|
@ -15,7 +15,11 @@ from django.db import models
|
|||
from django.db.models.fields import FieldDoesNotExist, Field as DjangoModelField
|
||||
from django.db.models import query
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from rest_framework.compat import postgres_fields, unicode_to_repr
|
||||
from rest_framework.compat import (
|
||||
postgres_fields,
|
||||
unicode_to_repr,
|
||||
DurationField as ModelDurationField,
|
||||
)
|
||||
from rest_framework.utils import model_meta
|
||||
from rest_framework.utils.field_mapping import (
|
||||
get_url_kwargs, get_field_kwargs,
|
||||
|
@ -731,6 +735,8 @@ class ModelSerializer(Serializer):
|
|||
models.TimeField: TimeField,
|
||||
models.URLField: URLField,
|
||||
}
|
||||
if ModelDurationField is not None:
|
||||
serializer_field_mapping[ModelDurationField] = DurationField
|
||||
serializer_related_field = PrimaryKeyRelatedField
|
||||
serializer_url_field = HyperlinkedIdentityField
|
||||
serializer_choice_field = ChoiceField
|
||||
|
@ -1088,6 +1094,9 @@ class ModelSerializer(Serializer):
|
|||
if extra_kwargs.get('default') and kwargs.get('required') is False:
|
||||
kwargs.pop('required')
|
||||
|
||||
if kwargs.get('read_only', False):
|
||||
extra_kwargs.pop('required', None)
|
||||
|
||||
kwargs.update(extra_kwargs)
|
||||
|
||||
return kwargs
|
||||
|
|
|
@ -9,7 +9,7 @@ from django.utils.encoding import smart_text
|
|||
from django.utils.translation import ugettext_lazy as _
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from rest_framework import status, exceptions
|
||||
from rest_framework.compat import HttpResponseBase, View
|
||||
from rest_framework.compat import HttpResponseBase, View, set_rollback
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.settings import api_settings
|
||||
|
@ -71,16 +71,21 @@ def exception_handler(exc, context):
|
|||
else:
|
||||
data = {'detail': exc.detail}
|
||||
|
||||
set_rollback()
|
||||
return Response(data, status=exc.status_code, headers=headers)
|
||||
|
||||
elif isinstance(exc, Http404):
|
||||
msg = _('Not found.')
|
||||
data = {'detail': six.text_type(msg)}
|
||||
|
||||
set_rollback()
|
||||
return Response(data, status=status.HTTP_404_NOT_FOUND)
|
||||
|
||||
elif isinstance(exc, PermissionDenied):
|
||||
msg = _('Permission denied.')
|
||||
data = {'detail': six.text_type(msg)}
|
||||
|
||||
set_rollback()
|
||||
return Response(data, status=status.HTTP_403_FORBIDDEN)
|
||||
|
||||
# Note: Unhandled exceptions will raise a 500 error.
|
||||
|
|
110
tests/test_atomic_requests.py
Normal file
110
tests/test_atomic_requests.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
from __future__ import unicode_literals
|
||||
|
||||
from django.db import connection, connections, transaction
|
||||
from django.test import TestCase
|
||||
from django.utils.unittest import skipUnless
|
||||
from rest_framework import status
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
from tests.models import BasicModel
|
||||
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
class BasicView(APIView):
|
||||
def post(self, request, *args, **kwargs):
|
||||
BasicModel.objects.create()
|
||||
return Response({'method': 'GET'})
|
||||
|
||||
|
||||
class ErrorView(APIView):
|
||||
def post(self, request, *args, **kwargs):
|
||||
BasicModel.objects.create()
|
||||
raise Exception
|
||||
|
||||
|
||||
class APIExceptionView(APIView):
|
||||
def post(self, request, *args, **kwargs):
|
||||
BasicModel.objects.create()
|
||||
raise APIException
|
||||
|
||||
|
||||
@skipUnless(connection.features.uses_savepoints,
|
||||
"'atomic' requires transactions and savepoints.")
|
||||
class DBTransactionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.view = BasicView.as_view()
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
||||
|
||||
def tearDown(self):
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = False
|
||||
|
||||
def test_no_exception_conmmit_transaction(self):
|
||||
request = factory.post('/')
|
||||
|
||||
with self.assertNumQueries(1):
|
||||
response = self.view(request)
|
||||
self.assertFalse(transaction.get_rollback())
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
assert BasicModel.objects.count() == 1
|
||||
|
||||
|
||||
@skipUnless(connection.features.uses_savepoints,
|
||||
"'atomic' requires transactions and savepoints.")
|
||||
class DBTransactionErrorTests(TestCase):
|
||||
def setUp(self):
|
||||
self.view = ErrorView.as_view()
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
||||
|
||||
def tearDown(self):
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = False
|
||||
|
||||
def test_generic_exception_delegate_transaction_management(self):
|
||||
"""
|
||||
Transaction is eventually managed by outer-most transaction atomic
|
||||
block. DRF do not try to interfere here.
|
||||
|
||||
We let django deal with the transaction when it will catch the Exception.
|
||||
"""
|
||||
request = factory.post('/')
|
||||
with self.assertNumQueries(3):
|
||||
# 1 - begin savepoint
|
||||
# 2 - insert
|
||||
# 3 - release savepoint
|
||||
with transaction.atomic():
|
||||
self.assertRaises(Exception, self.view, request)
|
||||
self.assertFalse(transaction.get_rollback())
|
||||
assert BasicModel.objects.count() == 1
|
||||
|
||||
|
||||
@skipUnless(connection.features.uses_savepoints,
|
||||
"'atomic' requires transactions and savepoints.")
|
||||
class DBTransactionAPIExceptionTests(TestCase):
|
||||
def setUp(self):
|
||||
self.view = APIExceptionView.as_view()
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
||||
|
||||
def tearDown(self):
|
||||
connections.databases['default']['ATOMIC_REQUESTS'] = False
|
||||
|
||||
def test_api_exception_rollback_transaction(self):
|
||||
"""
|
||||
Transaction is rollbacked by our transaction atomic block.
|
||||
"""
|
||||
request = factory.post('/')
|
||||
num_queries = (4 if getattr(connection.features,
|
||||
'can_release_savepoints', False) else 3)
|
||||
with self.assertNumQueries(num_queries):
|
||||
# 1 - begin savepoint
|
||||
# 2 - insert
|
||||
# 3 - rollback savepoint
|
||||
# 4 - release savepoint (django>=1.8 only)
|
||||
with transaction.atomic():
|
||||
response = self.view(request)
|
||||
self.assertTrue(transaction.get_rollback())
|
||||
self.assertEqual(response.status_code,
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
assert BasicModel.objects.count() == 0
|
|
@ -1,6 +1,7 @@
|
|||
from decimal import Decimal
|
||||
from django.utils import timezone
|
||||
from rest_framework import serializers
|
||||
import rest_framework
|
||||
import datetime
|
||||
import django
|
||||
import pytest
|
||||
|
@ -221,6 +222,14 @@ class TestInvalidErrorKey:
|
|||
assert str(exc_info.value) == expected
|
||||
|
||||
|
||||
class MockHTMLDict(dict):
|
||||
"""
|
||||
This class mocks up a dictionary like object, that behaves
|
||||
as if it was returned for multipart or urlencoded data.
|
||||
"""
|
||||
getlist = None
|
||||
|
||||
|
||||
class TestBooleanHTMLInput:
|
||||
def setup(self):
|
||||
class TestSerializer(serializers.Serializer):
|
||||
|
@ -234,21 +243,11 @@ class TestBooleanHTMLInput:
|
|||
"""
|
||||
# This class mocks up a dictionary like object, that behaves
|
||||
# as if it was returned for multipart or urlencoded data.
|
||||
class MockHTMLDict(dict):
|
||||
getlist = None
|
||||
serializer = self.Serializer(data=MockHTMLDict())
|
||||
assert serializer.is_valid()
|
||||
assert serializer.validated_data == {'archived': False}
|
||||
|
||||
|
||||
class MockHTMLDict(dict):
|
||||
"""
|
||||
This class mocks up a dictionary like object, that behaves
|
||||
as if it was returned for multipart or urlencoded data.
|
||||
"""
|
||||
getlist = None
|
||||
|
||||
|
||||
class TestHTMLInput:
|
||||
def test_empty_html_charfield(self):
|
||||
class TestSerializer(serializers.Serializer):
|
||||
|
@ -905,6 +904,29 @@ class TestNoOutputFormatTimeField(FieldValues):
|
|||
field = serializers.TimeField(format=None)
|
||||
|
||||
|
||||
@pytest.mark.skipif(django.VERSION < (1, 8),
|
||||
reason='DurationField is only available for django1.8+')
|
||||
class TestDurationField(FieldValues):
|
||||
"""
|
||||
Valid and invalid values for `DurationField`.
|
||||
"""
|
||||
valid_inputs = {
|
||||
'13': datetime.timedelta(seconds=13),
|
||||
'3 08:32:01.000123': datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
|
||||
'08:01': datetime.timedelta(minutes=8, seconds=1),
|
||||
datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123),
|
||||
}
|
||||
invalid_inputs = {
|
||||
'abc': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
|
||||
'3 08:32 01.123': ['Duration has wrong format. Use one of these formats instead: [DD] [HH:[MM:]]ss[.uuuuuu].'],
|
||||
}
|
||||
outputs = {
|
||||
datetime.timedelta(days=3, hours=8, minutes=32, seconds=1, microseconds=123): '3 08:32:01.000123',
|
||||
}
|
||||
if django.VERSION >= (1, 8):
|
||||
field = serializers.DurationField()
|
||||
|
||||
|
||||
# Choice types...
|
||||
|
||||
class TestChoiceField(FieldValues):
|
||||
|
@ -1017,6 +1039,15 @@ class TestMultipleChoiceField(FieldValues):
|
|||
]
|
||||
)
|
||||
|
||||
def test_against_partial_and_full_updates(self):
|
||||
# serializer = self.Serializer(data=MockHTMLDict())
|
||||
from django.http import QueryDict
|
||||
field = serializers.MultipleChoiceField(choices=(('a', 'a'), ('b', 'b')))
|
||||
field.partial = False
|
||||
assert field.get_value(QueryDict({})) == []
|
||||
field.partial = True
|
||||
assert field.get_value(QueryDict({})) == rest_framework.fields.empty
|
||||
|
||||
|
||||
# File serializers...
|
||||
|
||||
|
|
|
@ -6,13 +6,15 @@ These tests deal with ensuring that we correctly map the model fields onto
|
|||
an appropriate set of serializer fields for each case.
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
import django
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.core.validators import MaxValueValidator, MinValueValidator, MinLengthValidator
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
from django.utils import six
|
||||
import pytest
|
||||
from rest_framework import serializers
|
||||
from rest_framework.compat import unicode_repr
|
||||
from rest_framework.compat import unicode_repr, DurationField as ModelDurationField
|
||||
|
||||
|
||||
def dedent(blocktext):
|
||||
|
@ -284,6 +286,28 @@ class TestRegularFieldMappings(TestCase):
|
|||
ChildSerializer().fields
|
||||
|
||||
|
||||
@pytest.mark.skipif(django.VERSION < (1, 8),
|
||||
reason='DurationField is only available for django1.8+')
|
||||
class TestDurationFieldMapping(TestCase):
|
||||
def test_duration_field(self):
|
||||
class DurationFieldModel(models.Model):
|
||||
"""
|
||||
A model that defines DurationField.
|
||||
"""
|
||||
duration_field = ModelDurationField()
|
||||
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = DurationFieldModel
|
||||
|
||||
expected = dedent("""
|
||||
TestSerializer():
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
duration_field = DurationField()
|
||||
""")
|
||||
self.assertEqual(unicode_repr(TestSerializer()), expected)
|
||||
|
||||
|
||||
# Tests for relational field mappings.
|
||||
# ------------------------------------
|
||||
|
||||
|
@ -316,6 +340,14 @@ class RelationalModel(models.Model):
|
|||
through = models.ManyToManyField(ThroughTargetModel, through=Supplementary, related_name='reverse_through')
|
||||
|
||||
|
||||
class UniqueTogetherModel(models.Model):
|
||||
foreign_key = models.ForeignKey(ForeignKeyTargetModel, related_name='unique_foreign_key')
|
||||
one_to_one = models.OneToOneField(OneToOneTargetModel, related_name='unique_one_to_one')
|
||||
|
||||
class Meta:
|
||||
unique_together = ("foreign_key", "one_to_one")
|
||||
|
||||
|
||||
class TestRelationalFieldMappings(TestCase):
|
||||
def test_pk_relations(self):
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
|
@ -395,6 +427,32 @@ class TestRelationalFieldMappings(TestCase):
|
|||
""")
|
||||
self.assertEqual(unicode_repr(TestSerializer()), expected)
|
||||
|
||||
def test_nested_unique_together_relations(self):
|
||||
class TestSerializer(serializers.HyperlinkedModelSerializer):
|
||||
class Meta:
|
||||
model = UniqueTogetherModel
|
||||
depth = 1
|
||||
expected = dedent("""
|
||||
TestSerializer():
|
||||
url = HyperlinkedIdentityField(view_name='uniquetogethermodel-detail')
|
||||
foreign_key = NestedSerializer(read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='foreignkeytargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
one_to_one = NestedSerializer(read_only=True):
|
||||
url = HyperlinkedIdentityField(view_name='onetoonetargetmodel-detail')
|
||||
name = CharField(max_length=100)
|
||||
class Meta:
|
||||
validators = [<UniqueTogetherValidator(queryset=UniqueTogetherModel.objects.all(), fields=('foreign_key', 'one_to_one'))>]
|
||||
""")
|
||||
if six.PY2:
|
||||
# This case is also too awkward to resolve fully across both py2
|
||||
# and py3. (See above)
|
||||
expected = expected.replace(
|
||||
"('foreign_key', 'one_to_one')",
|
||||
"(u'foreign_key', u'one_to_one')"
|
||||
)
|
||||
self.assertEqual(unicode_repr(TestSerializer()), expected)
|
||||
|
||||
def test_pk_reverse_foreign_key(self):
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
|
|
Loading…
Reference in New Issue
Block a user