From 89e1cae28e77de85daa0acb1b438515741eec41f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Piquemal?= Date: Mon, 9 Jan 2012 20:05:55 +0200 Subject: [PATCH] better Resource.get_model and Resource.get_queryset methods (directly inspired from generic.DetailView) --- djangorestframework/resources.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index fde91e305..192d134cb 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -1,6 +1,7 @@ from django import forms from django.core.urlresolvers import reverse, get_urlconf, get_resolver, NoReverseMatch from django.db import models +from django.core.exceptions import ImproperlyConfigured from djangorestframework.response import ErrorResponse from djangorestframework.serializer import Serializer, _SkipField @@ -529,18 +530,36 @@ class ModelResource(FormResource): def get_model(self): """ - Return the model class for this view. + Return the model class for this resource. """ - return getattr(self, 'model', getattr(self.view, 'model', None)) + model = getattr(self, 'model', None) + if model is None: + model = getattr(self.view, 'model', None) + if model is None: + raise ImproperlyConfigured(u"%(cls)s is missing a model. Define " + u"%(cls)s.model." % { + 'cls': self.__class__ + }) + return model def get_queryset(self): """ Return the queryset that should be used when retrieving or listing instances. """ - return getattr(self, 'queryset', - getattr(self.view, 'queryset', - self.get_model().objects.all())) + queryset = getattr(self, 'queryset', None) + if queryset is None: + queryset = getattr(self.view, 'queryset', None) + if queryset is None: + try: + model = self.get_model() + except ImproperlyConfigured: + raise ImproperlyConfigured(u"%(cls)s is missing a queryset. Define " + u"%(cls)s.model or %(cls)s.queryset." % { + 'cls': self.__class__ + }) + queryset = model._default_manager.all() + return queryset._clone() def get_ordering(self): """