creating a method to return custom response data for generic views

This commit is contained in:
Ekluv 2017-03-12 00:41:53 +05:30
parent 1d34bc0b92
commit 39c7755a08

View File

@ -11,7 +11,18 @@ from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
class CreateModelMixin(object): class BaseModelMixin(object):
def get_response_data(self, data, *args, **kwargs):
"""
Get the data to return as response data
By default it returns `serializer.data`
Override this method to return custom response data
"""
return data
class CreateModelMixin(BaseModelMixin):
""" """
Create a model instance. Create a model instance.
""" """
@ -20,7 +31,8 @@ class CreateModelMixin(object):
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.perform_create(serializer) self.perform_create(serializer)
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) response_data = self.get_response_data(serializer.data)
return Response(response_data, status=status.HTTP_201_CREATED, headers=headers)
def perform_create(self, serializer): def perform_create(self, serializer):
serializer.save() serializer.save()
@ -32,7 +44,7 @@ class CreateModelMixin(object):
return {} return {}
class ListModelMixin(object): class ListModelMixin(BaseModelMixin):
""" """
List a queryset. List a queryset.
""" """
@ -45,20 +57,22 @@ class ListModelMixin(object):
return self.get_paginated_response(serializer.data) return self.get_paginated_response(serializer.data)
serializer = self.get_serializer(queryset, many=True) serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data) response_data = self.get_response_data(serializer.data)
return Response(response_data)
class RetrieveModelMixin(object): class RetrieveModelMixin(BaseModelMixin):
""" """
Retrieve a model instance. Retrieve a model instance.
""" """
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
instance = self.get_object() instance = self.get_object()
serializer = self.get_serializer(instance) serializer = self.get_serializer(instance)
return Response(serializer.data) response_data = self.get_response_data(serializer.data)
return Response(response_data)
class UpdateModelMixin(object): class UpdateModelMixin(BaseModelMixin):
""" """
Update a model instance. Update a model instance.
""" """
@ -74,7 +88,8 @@ class UpdateModelMixin(object):
# forcibly invalidate the prefetch cache on the instance. # forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {} instance._prefetched_objects_cache = {}
return Response(serializer.data) response_data = self.get_response_data(serializer.data)
return Response(response_data)
def perform_update(self, serializer): def perform_update(self, serializer):
serializer.save() serializer.save()