diff --git a/rest_framework/fields.py b/rest_framework/fields.py index fdfba13f2..f2dd9d74d 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -283,14 +283,25 @@ class CreateOnlyDefault: return '%s(%s)' % (self.__class__.__name__, repr(self.default)) -class CurrentUserDefault: +class ValueFromContext: requires_context = True + def __init__(self, context_name): + self.context_name = context_name + def __call__(self, serializer_field): - return serializer_field.context['request'].user + return serializer_field.context[self.context_name] def __repr__(self): - return '%s()' % self.__class__.__name__ + return "%s()" % self.__class__.__name__ + + +class CurrentUserDefault(ValueFromContext): + def __init__(self, *args, **kwargs): + super().__init__(context_name="request") + + def __call__(self, serializer_field): + return super().__call__(serializer_field).user class SkipField(Exception): diff --git a/tests/test_fields.py b/tests/test_fields.py index a83c0bd68..439d3a92c 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -15,7 +15,8 @@ from django.utils.timezone import activate, deactivate, override, utc import rest_framework from rest_framework import exceptions, serializers from rest_framework.fields import ( - BuiltinSignatureError, CurrentUserDefault, DjangoImageField, is_simple_callable + BuiltinSignatureError, CurrentUserDefault, DjangoImageField, + ValueFromContext, is_simple_callable ) # Tests for helper functions. @@ -2407,3 +2408,27 @@ class TestCurrentUserDefault: with pytest.raises(KeyError) as exc_info: serializer.is_valid() assert str(exc_info.value) == "'request'" + + +class ValueFromContextSerializer(serializers.Serializer): + vocalization = serializers.HiddenField(default=ValueFromContext("vocalization")) + + +class TestValueFromContext: + def test_context_set(self): + serializer = ValueFromContextSerializer(data={}, context={"vocalization": "meow"}) + serializer.is_valid() + field = serializer.fields["vocalization"] + assert field.get_default() == "meow" + + def test_context_set_none(self): + serializer = ValueFromContextSerializer(data={}, context={"vocalization": None}) + serializer.is_valid() + field = serializer.fields["vocalization"] + assert field.get_default() is None + + def test_missing_context(self): + serializer = ValueFromContextSerializer(data={}, context={}) + with pytest.raises(KeyError) as exc_info: + serializer.is_valid() + assert str(exc_info.value) == "'vocalization'"