Allow custom serializer class for request and response

This commit is contained in:
khamaileon 2022-04-18 09:25:55 +02:00
parent f378f98a40
commit dab91c36c3
4 changed files with 144 additions and 2 deletions

View File

@ -295,6 +295,8 @@ As with `ModelViewSet`, you'll normally need to provide at least the `queryset`
Again, as with `ModelViewSet`, you can use any of the standard attributes and method overrides available to `GenericAPIView`.
---
# Custom ViewSet base classes
You may need to provide custom `ViewSet` classes that do not have the full set of `ModelViewSet` actions, or that customize the behavior in some other way.
@ -321,3 +323,44 @@ By creating your own base `ViewSet` classes, you can provide common behavior tha
[cite]: https://guides.rubyonrails.org/action_controller_overview.html
[routers]: routers.md
---
# Custom serializer for request and response
It is possible to define at the view level (or for each custom method via the @action decorator) a custom serialization class for each request and response.
To do this you need to define `request_serializer_response` and `response_serializer_response` and call them via `get_request_serializer` and `get_response_serializer`.
class UserViewSet(viewsets.ModelViewSet):
"""
A viewset that provides the standard actions
"""
queryset = User.objects.all()
serializer_class = UserSerializer
@action(
detail=True,
methods=['post'],
request_serializer_class=PasswordSerializer
)
def set_password(self, request, pk=None):
user = self.get_object()
serializer = self.get_request_serializer(data=request.data)
if serializer.is_valid():
user.set_password(serializer.validated_data['password'])
user.save()
return Response({'status': 'password set'})
else:
return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST)
@action(
detail=True,
methods=['get'],
request_response_class=ExtendedUserSerializer
)
def complete_profile(self, request, pk=None):
user = self.get_object()
response_serializer = self.get_response_serializer(user)
return Response(response_serializer.data)

View File

@ -33,6 +33,8 @@ class GenericAPIView(views.APIView):
# for all subsequent requests.
queryset = None
serializer_class = None
request_serializer_class = None
response_serializer_class = None
# If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
@ -109,6 +111,23 @@ class GenericAPIView(views.APIView):
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)
def get_request_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for validating and
deserializing input.
"""
serializer_class = self.get_request_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)
def get_response_serializer(self, *args, **kwargs):
"""
Return the serializer instance that should be used for serializing output.
"""
serializer_class = self.get_response_serializer_class()
kwargs.setdefault('context', self.get_serializer_context())
return serializer_class(*args, **kwargs)
def get_serializer_class(self):
"""
Return the class to use for the serializer.
@ -127,6 +146,18 @@ class GenericAPIView(views.APIView):
return self.serializer_class
def get_request_serializer_class(self):
"""
Return the class to use as input serializer.
"""
return self.request_serializer_class or self.get_serializer_class()
def get_response_serializer_class(self):
"""
Returns the class to use as output serializer.
"""
return self.response_serializer_class or self.get_serializer_class()
def get_serializer_context(self):
"""
Extra context provided to the serializer class.

View File

@ -627,15 +627,41 @@ class AutoSchema(ViewInspector):
Override this method if your view uses a different serializer for
handling request body.
"""
view = self.view
if not hasattr(view, "get_request_serializer"):
return self.get_serializer(path, method)
try:
return view.get_request_serializer()
except exceptions.APIException:
warnings.warn(
"{}.get_request_serializer() raised an exception during "
"schema generation. Serializer fields will not be "
"generated for {} {}.".format(view.__class__.__name__, method, path)
)
return None
def get_response_serializer(self, path, method):
"""
Override this method if your view uses a different serializer for
populating response data.
"""
view = self.view
if not hasattr(view, "get_response_serializer"):
return self.get_serializer(path, method)
try:
return view.get_response_serializer()
except exceptions.APIException:
warnings.warn(
"{}.get_response_serializer() raised an exception during "
"schema generation. Serializer fields will not be "
"generated for {} {}.".format(view.__class__.__name__, method, path)
)
return None
def _get_reference(self, serializer):
return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))}

View File

@ -692,3 +692,45 @@ class TestSerializer(TestCase):
serializer = response.serializer
assert serializer.context is context
def test_get_request_serializer_class(self):
class View(generics.GenericAPIView):
request_serializer_class = BasicSerializer
view = View()
assert view.get_request_serializer_class() == BasicSerializer
def test_get_response_serializer_class(self):
class TestResponseSerializerView(generics.GenericAPIView):
response_serializer_class = BasicSerializer
view = TestResponseSerializerView()
assert view.get_response_serializer_class() == BasicSerializer
def test_get_request_serializer(self):
class View(generics.ListAPIView):
request_serializer_class = BasicSerializer
def list(self, request):
response = Response()
response.serializer = self.get_request_serializer()
return response
view = View.as_view()
request = factory.get('/')
response = view(request)
assert isinstance(response.serializer, BasicSerializer)
def test_get_response_serializer(self):
class View(generics.ListAPIView):
response_serializer_class = BasicSerializer
def list(self, request):
response = Response()
response.serializer = self.get_response_serializer()
return response
view = View.as_view()
request = factory.get('/')
response = view(request)
assert isinstance(response.serializer, BasicSerializer)