diff --git a/rest_framework/generics.py b/rest_framework/generics.py index c39b02ab7..55cfafda4 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -106,7 +106,7 @@ class GenericAPIView(views.APIView): deserializing input, and for serializing output. """ serializer_class = self.get_serializer_class() - kwargs['context'] = self.get_serializer_context() + kwargs.setdefault('context', self.get_serializer_context()) return serializer_class(*args, **kwargs) def get_serializer_class(self): diff --git a/tests/test_generics.py b/tests/test_generics.py index 0b91e3465..2907d2773 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -662,3 +662,33 @@ class GetObjectOr404Tests(TestCase): def test_get_object_or_404_with_invalid_string_for_uuid(self): with pytest.raises(Http404): generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid') + + +class TestSerializer(TestCase): + + def test_serializer_class_not_provided(self): + class NoSerializerClass(generics.GenericAPIView): + pass + + with pytest.raises(AssertionError) as excinfo: + NoSerializerClass().get_serializer_class() + + assert str(excinfo.value) == ( + "'NoSerializerClass' should either include a `serializer_class` " + "attribute, or override the `get_serializer_class()` method.") + + def test_given_context_not_overridden(self): + context = object() + + class View(generics.ListAPIView): + serializer_class = serializers.Serializer + + def list(self, request): + response = Response() + response.serializer = self.get_serializer(context=context) + return response + + response = View.as_view()(factory.get('/')) + serializer = response.serializer + + assert serializer.context is context