mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-27 16:40:03 +03:00
implement ValueFromContext helper as a more generic version of CurrentUserDefault
This commit is contained in:
parent
5bc70e9d32
commit
b2ec2e5b0f
|
@ -283,14 +283,25 @@ class CreateOnlyDefault:
|
||||||
return '%s(%s)' % (self.__class__.__name__, repr(self.default))
|
return '%s(%s)' % (self.__class__.__name__, repr(self.default))
|
||||||
|
|
||||||
|
|
||||||
class CurrentUserDefault:
|
class ValueFromContext:
|
||||||
requires_context = True
|
requires_context = True
|
||||||
|
|
||||||
|
def __init__(self, context_name):
|
||||||
|
self.context_name = context_name
|
||||||
|
|
||||||
def __call__(self, serializer_field):
|
def __call__(self, serializer_field):
|
||||||
return serializer_field.context['request'].user
|
return serializer_field.context[self.context_name]
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '%s()' % self.__class__.__name__
|
return "%s()" % self.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentUserDefault(ValueFromContext):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(context_name="request")
|
||||||
|
|
||||||
|
def __call__(self, serializer_field):
|
||||||
|
return super().__call__(serializer_field).user
|
||||||
|
|
||||||
|
|
||||||
class SkipField(Exception):
|
class SkipField(Exception):
|
||||||
|
|
|
@ -15,7 +15,8 @@ from django.utils.timezone import activate, deactivate, override, utc
|
||||||
import rest_framework
|
import rest_framework
|
||||||
from rest_framework import exceptions, serializers
|
from rest_framework import exceptions, serializers
|
||||||
from rest_framework.fields import (
|
from rest_framework.fields import (
|
||||||
BuiltinSignatureError, CurrentUserDefault, DjangoImageField, is_simple_callable
|
BuiltinSignatureError, CurrentUserDefault, DjangoImageField,
|
||||||
|
ValueFromContext, is_simple_callable
|
||||||
)
|
)
|
||||||
|
|
||||||
# Tests for helper functions.
|
# Tests for helper functions.
|
||||||
|
@ -2407,3 +2408,27 @@ class TestCurrentUserDefault:
|
||||||
with pytest.raises(KeyError) as exc_info:
|
with pytest.raises(KeyError) as exc_info:
|
||||||
serializer.is_valid()
|
serializer.is_valid()
|
||||||
assert str(exc_info.value) == "'request'"
|
assert str(exc_info.value) == "'request'"
|
||||||
|
|
||||||
|
|
||||||
|
class ValueFromContextSerializer(serializers.Serializer):
|
||||||
|
vocalization = serializers.HiddenField(default=ValueFromContext("vocalization"))
|
||||||
|
|
||||||
|
|
||||||
|
class TestValueFromContext:
|
||||||
|
def test_context_set(self):
|
||||||
|
serializer = ValueFromContextSerializer(data={}, context={"vocalization": "meow"})
|
||||||
|
serializer.is_valid()
|
||||||
|
field = serializer.fields["vocalization"]
|
||||||
|
assert field.get_default() == "meow"
|
||||||
|
|
||||||
|
def test_context_set_none(self):
|
||||||
|
serializer = ValueFromContextSerializer(data={}, context={"vocalization": None})
|
||||||
|
serializer.is_valid()
|
||||||
|
field = serializer.fields["vocalization"]
|
||||||
|
assert field.get_default() is None
|
||||||
|
|
||||||
|
def test_missing_context(self):
|
||||||
|
serializer = ValueFromContextSerializer(data={}, context={})
|
||||||
|
with pytest.raises(KeyError) as exc_info:
|
||||||
|
serializer.is_valid()
|
||||||
|
assert str(exc_info.value) == "'vocalization'"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user