diff --git a/rest_framework/viewsets.py b/rest_framework/viewsets.py index a9c90a8d9..7d66dd494 100644 --- a/rest_framework/viewsets.py +++ b/rest_framework/viewsets.py @@ -27,6 +27,7 @@ from django.views.decorators.csrf import csrf_exempt from rest_framework import generics, mixins, views from rest_framework.decorators import MethodMapper from rest_framework.reverse import reverse +from rest_framework.renderers import TemplateHTMLRenderer,JSONRenderer def _is_extra_action(attr): @@ -253,3 +254,55 @@ class ModelViewSet(mixins.CreateModelMixin, `partial_update()`, `destroy()` and `list()` actions. """ pass + + + +class UsefulViewSet(ModelViewSet): + renderer_classes = [TemplateHTMLRenderer,JSONRenderer] + # serializers + list_serializer = None + create_serializer = None + update_serializer = None + retrieve_serializer = None + # templates + update_template = None + list_template = None + + def get_list_extra_context(self,request,*args, **kwargs): + return {} + + def get_retrieve_extra_context(self,request,*args, **kwargs): + return {} + + def _get_template_name(self,template_name): + return template_name if template_name else self.template_name + + def list(self, request, *args, **kwargs): + self.template_name = self._get_template_name(self.list_template) + response = super().list(request, *args, **kwargs) + response.data.update(self.get_list_extra_context(request, *args, **kwargs)) + return response + + def retrieve(self, request, *args, **kwargs): + self.template_name = self._get_template_name(self.update_template) + response = super().retrieve(request, *args, **kwargs) + response.data.update(self.get_retrieve_extra_context(request, *args, **kwargs)) + return response + + + def _get_serializer_class(self,serializer): + return serializer if serializer else self.serializer_class + + def get_serializer_class(self): + if self.action == "list": + return self._get_serializer_class(self.list_serializer) + elif self.action == "create": + return self._get_serializer_class(self.create_serializer) + elif self.action in ["retrieve",]: + return self._get_serializer_class(self.retrieve_serializer) + elif self.action == 'update': + return self._get_serializer_class(self.update_serializer) + return + + +