mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-27 16:40:03 +03:00
implement ValueFromContext helper as a more generic version of CurrentUserDefault
This commit is contained in:
parent
5bc70e9d32
commit
b2ec2e5b0f
|
@ -283,14 +283,25 @@ class CreateOnlyDefault:
|
|||
return '%s(%s)' % (self.__class__.__name__, repr(self.default))
|
||||
|
||||
|
||||
class CurrentUserDefault:
|
||||
class ValueFromContext:
|
||||
requires_context = True
|
||||
|
||||
def __init__(self, context_name):
|
||||
self.context_name = context_name
|
||||
|
||||
def __call__(self, serializer_field):
|
||||
return serializer_field.context['request'].user
|
||||
return serializer_field.context[self.context_name]
|
||||
|
||||
def __repr__(self):
|
||||
return '%s()' % self.__class__.__name__
|
||||
return "%s()" % self.__class__.__name__
|
||||
|
||||
|
||||
class CurrentUserDefault(ValueFromContext):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(context_name="request")
|
||||
|
||||
def __call__(self, serializer_field):
|
||||
return super().__call__(serializer_field).user
|
||||
|
||||
|
||||
class SkipField(Exception):
|
||||
|
|
|
@ -15,7 +15,8 @@ from django.utils.timezone import activate, deactivate, override, utc
|
|||
import rest_framework
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.fields import (
|
||||
BuiltinSignatureError, CurrentUserDefault, DjangoImageField, is_simple_callable
|
||||
BuiltinSignatureError, CurrentUserDefault, DjangoImageField,
|
||||
ValueFromContext, is_simple_callable
|
||||
)
|
||||
|
||||
# Tests for helper functions.
|
||||
|
@ -2407,3 +2408,27 @@ class TestCurrentUserDefault:
|
|||
with pytest.raises(KeyError) as exc_info:
|
||||
serializer.is_valid()
|
||||
assert str(exc_info.value) == "'request'"
|
||||
|
||||
|
||||
class ValueFromContextSerializer(serializers.Serializer):
|
||||
vocalization = serializers.HiddenField(default=ValueFromContext("vocalization"))
|
||||
|
||||
|
||||
class TestValueFromContext:
|
||||
def test_context_set(self):
|
||||
serializer = ValueFromContextSerializer(data={}, context={"vocalization": "meow"})
|
||||
serializer.is_valid()
|
||||
field = serializer.fields["vocalization"]
|
||||
assert field.get_default() == "meow"
|
||||
|
||||
def test_context_set_none(self):
|
||||
serializer = ValueFromContextSerializer(data={}, context={"vocalization": None})
|
||||
serializer.is_valid()
|
||||
field = serializer.fields["vocalization"]
|
||||
assert field.get_default() is None
|
||||
|
||||
def test_missing_context(self):
|
||||
serializer = ValueFromContextSerializer(data={}, context={})
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
serializer.is_valid()
|
||||
assert str(exc_info.value) == "'vocalization'"
|
||||
|
|
Loading…
Reference in New Issue
Block a user