From 325e63a3a767bf4aedef7be616cc268a08537424 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 13 May 2011 17:19:12 +0100 Subject: [PATCH] Sorting out resources. Doing some crazy magic automatic url resolving stuff. Yum. --- djangorestframework/mixins.py | 65 +++- djangorestframework/resources.py | 466 +++++++----------------- djangorestframework/tests/resources.py | 2 +- djangorestframework/utils/mediatypes.py | 1 + djangorestframework/views.py | 14 +- 5 files changed, 200 insertions(+), 348 deletions(-) diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 12f2d779a..70ec677ea 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -25,6 +25,8 @@ __all__ = ( 'ResponseMixin', 'AuthMixin', 'ResourceMixin', + # + 'InstanceMixin', # Model behavior mixins 'ReadModelMixin', 'CreateModelMixin', @@ -137,7 +139,7 @@ class RequestMixin(object): content_length = 0 # TODO: Add 1.3's LimitedStream to compat and use that. - # Currently only supports parsing request body as a stream with 1.3 + # NOTE: Currently only supports parsing request body as a stream with 1.3 if content_length == 0: return None elif hasattr(request, 'read'): @@ -379,8 +381,8 @@ class AuthMixin(object): if not hasattr(self, '_user'): self._user = self._authenticate() return self._user - - + + def _authenticate(self): """ Attempt to authenticate the request using each authentication class in turn. @@ -405,26 +407,71 @@ class AuthMixin(object): permission.check_permission(user) +########## + +class InstanceMixin(object): + """ + Mixin class that is used to identify a view class as being the canonical identifier + for the resources it is mapped too. + """ + + @classmethod + def as_view(cls, **initkwargs): + """ + Store the callable object on the resource class that has been associated with this view. + """ + view = super(InstanceMixin, cls).as_view(**initkwargs) + if 'resource' in initkwargs: + # We do a little dance when we store the view callable... + # we need to store it wrapped in a 1-tuple, so that inspect will treat it + # as a function when we later look it up (rather than turning it into a method). + # This makes sure our URL reversing works ok. + initkwargs['resource'].view_callable = (view,) + return view + ########## Resource Mixin ########## class ResourceMixin(object): + """ + Provides request validation and response filtering behavior. + """ + + """ + Should be a class as described in the ``resources`` module. + + The ``resource`` is an object that maps a view onto it's representation on the server. + + It provides validation on the content of incoming requests, + and filters the object representation into a serializable object for the response. + """ + resource = None + @property def CONTENT(self): if not hasattr(self, '_content'): - self._content = self._get_content() + self._content = self.validate_request(self.DATA, self.FILES) return self._content - def _get_content(self): + def validate_request(self, data, files): + """ + Given the request data return the cleaned, validated content. + Typically raises a ErrorResponse with status code 400 (Bad Request) on failure. + """ resource = self.resource(self) - return resource.validate(self.DATA, self.FILES) + return resource.validate_request(data, files) + + def filter_response(self, obj): + """ + Given the response content, filter it into a serializable object. + """ + resource = self.resource(self) + return resource.filter_response(obj) def get_bound_form(self, content=None): resource = self.resource(self) return resource.get_bound_form(content) - def object_to_data(self, obj): - resource = self.resource(self) - return resource.object_to_data(obj) + ########## Model Mixins ########## diff --git a/djangorestframework/resources.py b/djangorestframework/resources.py index f47b41d0e..31b9b0141 100644 --- a/djangorestframework/resources.py +++ b/djangorestframework/resources.py @@ -1,3 +1,5 @@ +from django import forms +from django.core.urlresolvers import reverse, get_urlconf, get_resolver, NoReverseMatch from django.db import models from django.db.models.query import QuerySet from django.db.models.fields.related import RelatedField @@ -9,10 +11,15 @@ import re -def _model_to_dict(instance, fields=None, exclude=None): +def _model_to_dict(instance, resource=None): """ - This is a clone of Django's ``django.forms.model_to_dict`` except that it - doesn't coerce related objects into primary keys. + Given a model instance, return a ``dict`` representing the model. + + The implementation is similar to Django's ``django.forms.model_to_dict``, except: + + * It doesn't coerce related objects into primary keys. + * It doesn't drop ``editable=False`` fields. + * It also supports attribute or method fields on the instance or resource. """ opts = instance._meta data = {} @@ -20,10 +27,19 @@ def _model_to_dict(instance, fields=None, exclude=None): #print [rel.name for rel in opts.get_all_related_objects()] #related = [rel.get_accessor_name() for rel in opts.get_all_related_objects()] #print [getattr(instance, rel) for rel in related] + #if resource.fields: + # fields = resource.fields + #else: + # fields = set(opts.fields + opts.many_to_many) + + fields = resource.fields + include = resource.include + exclude = resource.exclude + extra_fields = fields and list(resource.fields) or [] + + # Model fields for f in opts.fields + opts.many_to_many: - #if not f.editable: - # continue if fields and not f.name in fields: continue if exclude and f.name in exclude: @@ -32,87 +48,84 @@ def _model_to_dict(instance, fields=None, exclude=None): data[f.name] = getattr(instance, f.name) else: data[f.name] = f.value_from_object(instance) - - #print fields - (opts.fields + opts.many_to_many) - #for related in [rel.get_accessor_name() for rel in opts.get_all_related_objects()]: - # if fields and not related in fields: - # continue - # if exclude and related in exclude: - # continue - # data[related] = getattr(instance, related) + + if extra_fields and f.name in extra_fields: + extra_fields.remove(f.name) + # Method fields + for fname in extra_fields: + if hasattr(resource, fname): + # check the resource first, to allow it to override fields + obj = getattr(resource, fname) + # if it's a method like foo(self, instance), then call it + if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) == 2: + obj = obj(instance) + elif hasattr(instance, fname): + # now check the object instance + obj = getattr(instance, fname) + else: + continue + + # TODO: It would be nicer if this didn't recurse here. + # Let's keep _model_to_dict flat, and _object_to_data recursive. + data[fname] = _object_to_data(obj) + return data -def _object_to_data(obj): +def _object_to_data(obj, resource=None): """ Convert an object into a serializable representation. """ if isinstance(obj, dict): # dictionaries - return dict([ (key, _object_to_data(val)) for key, val in obj.iteritems() ]) + # TODO: apply same _model_to_dict logic fields/exclude here + return dict([ (key, _object_to_data(val)) for key, val in obj.iteritems() ]) if isinstance(obj, (tuple, list, set, QuerySet)): # basic iterables - return [_object_to_data(item) for item in obj] + return [_object_to_data(item, resource) for item in obj] if isinstance(obj, models.Manager): # Manager objects - return [_object_to_data(item) for item in obj.all()] + return [_object_to_data(item, resource) for item in obj.all()] if isinstance(obj, models.Model): # Model instances - return _object_to_data(_model_to_dict(obj)) + return _object_to_data(_model_to_dict(obj, resource)) if isinstance(obj, decimal.Decimal): # Decimals (force to string representation) return str(obj) if inspect.isfunction(obj) and not inspect.getargspec(obj)[0]: # function with no args - return _object_to_data(obj()) - if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) == 1: - # method with only a 'self' args - return _object_to_data(obj()) + return _object_to_data(obj(), resource) + if inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1: + # bound method + return _object_to_data(obj(), resource) - # fallback return smart_unicode(obj, strings_only=True) -def _form_to_data(form): - """ - Returns a dict containing the data in a form instance. - - This code is pretty much a clone of the ``Form.as_p()`` ``Form.as_ul`` - and ``Form.as_table()`` methods, except that it returns data suitable - for arbitrary serialization, rather than rendering the result directly - into html. - """ - ret = {} - for name, field in form.fields.items(): - if not form.is_bound: - data = form.initial.get(name, field.initial) - if callable(data): - data = data() - else: - if isinstance(field, FileField) and form.data is None: - data = form.initial.get(name, field.initial) - else: - data = field.widget.value_from_datadict(form.data, form.files, name) - ret[name] = field.prepare_value(data) - return ret - - class BaseResource(object): - """Base class for all Resource classes, which simply defines the interface they provide.""" + """ + Base class for all Resource classes, which simply defines the interface they provide. + """ + fields = None + include = None + exclude = None def __init__(self, view): self.view = view - def validate(self, data, files): - """Given some content as input return some cleaned, validated content. + def validate_request(self, data, files): + """ + Given the request data return the cleaned, validated content. Typically raises a ErrorResponse with status code 400 (Bad Request) on failure. - - Must be overridden to be implemented.""" + """ return data - def object_to_data(self, obj): - return _object_to_data(obj) + def filter_response(self, obj): + """ + Given the response content, filter it into a serializable object. + """ + return _object_to_data(obj, self) class Resource(BaseResource): @@ -135,247 +148,18 @@ class Resource(BaseResource): # you should explicitly set the fields attribute on your class. fields = None - # TODO: Replace this with new Serializer code based on Forms API. - def object_to_data(self, obj): - """ - A (horrible) munging of Piston's pre-serialization. Returns a dict. - """ - - return _object_to_data(obj) - - def _any(thing, fields=()): - """ - Dispatch, all types are routed through here. - """ - ret = None - - if isinstance(thing, QuerySet): - ret = _qs(thing, fields=fields) - elif isinstance(thing, (tuple, list)): - ret = _list(thing) - elif isinstance(thing, dict): - ret = _dict(thing) - elif isinstance(thing, int): - ret = thing - elif isinstance(thing, bool): - ret = thing - elif isinstance(thing, type(None)): - ret = thing - elif isinstance(thing, decimal.Decimal): - ret = str(thing) - elif isinstance(thing, models.Model): - ret = _model(thing, fields=fields) - #elif isinstance(thing, HttpResponse): TRC - # raise HttpStatusCode(thing) - elif inspect.isfunction(thing): - if not inspect.getargspec(thing)[0]: - ret = _any(thing()) - elif hasattr(thing, '__rendertable__'): - f = thing.__rendertable__ - if inspect.ismethod(f) and len(inspect.getargspec(f)[0]) == 1: - ret = _any(f()) - else: - ret = unicode(thing) # TRC - - return ret - - def _fk(data, field): - """ - Foreign keys. - """ - return _any(getattr(data, field.name)) - - def _related(data, fields=()): - """ - Foreign keys. - """ - return [ _model(m, fields) for m in data.iterator() ] - - def _m2m(data, field, fields=()): - """ - Many to many (re-route to `_model`.) - """ - return [ _model(m, fields) for m in getattr(data, field.name).iterator() ] - - - def _method_fields(data, fields): - if not data: - return { } - - has = dir(data) - ret = dict() - - for field in fields: - if field in has: - ret[field] = getattr(data, field) - - return ret - - def _model(data, fields=()): - """ - Models. Will respect the `fields` and/or - `exclude` on the handler (see `typemapper`.) - """ - ret = { } - #handler = self.in_typemapper(type(data), self.anonymous) # TRC - handler = None # TRC - get_absolute_url = False - - if fields: - v = lambda f: getattr(data, f.attname) - - get_fields = set(fields) - if 'absolute_url' in get_fields: # MOVED (TRC) - get_absolute_url = True - - met_fields = _method_fields(handler, get_fields) # TRC - - for f in data._meta.local_fields: - if f.serialize and not any([ p in met_fields for p in [ f.attname, f.name ]]): - if not f.rel: - if f.attname in get_fields: - ret[f.attname] = _any(v(f)) - get_fields.remove(f.attname) - else: - if f.attname[:-3] in get_fields: - ret[f.name] = _fk(data, f) - get_fields.remove(f.name) - - for mf in data._meta.many_to_many: - if mf.serialize and mf.attname not in met_fields: - if mf.attname in get_fields: - ret[mf.name] = _m2m(data, mf) - get_fields.remove(mf.name) - - # try to get the remainder of fields - for maybe_field in get_fields: - - if isinstance(maybe_field, (list, tuple)): - model, fields = maybe_field - inst = getattr(data, model, None) - - if inst: - if hasattr(inst, 'all'): - ret[model] = _related(inst, fields) - elif callable(inst): - if len(inspect.getargspec(inst)[0]) == 1: - ret[model] = _any(inst(), fields) - else: - ret[model] = _model(inst, fields) - - elif maybe_field in met_fields: - # Overriding normal field which has a "resource method" - # so you can alter the contents of certain fields without - # using different names. - ret[maybe_field] = _any(met_fields[maybe_field](data)) - - else: - maybe = getattr(data, maybe_field, None) - if maybe: - if callable(maybe): - if len(inspect.getargspec(maybe)[0]) == 1: - ret[maybe_field] = _any(maybe()) - else: - ret[maybe_field] = _any(maybe) - else: - pass # TRC - #handler_f = getattr(handler or self.handler, maybe_field, None) - # - #if handler_f: - # ret[maybe_field] = _any(handler_f(data)) - - else: - # Add absolute_url if it exists - get_absolute_url = True - - # Add all the fields - for f in data._meta.fields: - if f.attname != 'id': - ret[f.attname] = _any(getattr(data, f.attname)) - - # Add all the propertiess - klass = data.__class__ - for attr in dir(klass): - if not attr.startswith('_') and not attr in ('pk','id') and isinstance(getattr(klass, attr, None), property): - #if attr.endswith('_url') or attr.endswith('_uri'): - # ret[attr] = self.make_absolute(_any(getattr(data, attr))) - #else: - ret[attr] = _any(getattr(data, attr)) - #fields = dir(data.__class__) + ret.keys() - #add_ons = [k for k in dir(data) if k not in fields and not k.startswith('_')] - #print add_ons - ###print dir(data.__class__) - #from django.db.models import Model - #model_fields = dir(Model) - - #for attr in dir(data): - ## #if attr.startswith('_'): - ## # continue - # if (attr in fields) and not (attr in model_fields) and not attr.startswith('_'): - # print attr, type(getattr(data, attr, None)), attr in fields, attr in model_fields - - #for k in add_ons: - # ret[k] = _any(getattr(data, k)) - - # TRC - # resouce uri - #if self.in_typemapper(type(data), self.anonymous): - # handler = self.in_typemapper(type(data), self.anonymous) - # if hasattr(handler, 'resource_uri'): - # url_id, fields = handler.resource_uri() - # ret['resource_uri'] = permalink( lambda: (url_id, - # (getattr(data, f) for f in fields) ) )() - - # TRC - #if hasattr(data, 'get_api_url') and 'resource_uri' not in ret: - # try: ret['resource_uri'] = data.get_api_url() - # except: pass - - # absolute uri - if hasattr(data, 'get_absolute_url') and get_absolute_url: - try: ret['absolute_url'] = data.get_absolute_url() - except: pass - - #for key, val in ret.items(): - # if key.endswith('_url') or key.endswith('_uri'): - # ret[key] = self.add_domain(val) - - return ret - - def _qs(data, fields=()): - """ - Querysets. - """ - return [ _any(v, fields) for v in data ] - - def _list(data): - """ - Lists. - """ - return [ _any(v) for v in data ] - - def _dict(data): - """ - Dictionaries. - """ - return dict([ (k, _any(v)) for k, v in data.iteritems() ]) - - # Kickstart the seralizin'. - return _any(obj, self.fields) - class FormResource(Resource): - """Validator class that uses forms for validation. + """ + Resource class that uses forms for validation. Also provides a get_bound_form() method which may be used by some renderers. - - The view class should provide `.form` attribute which specifies the form classmethod - to be used for validation. - + On calling validate() this validator may set a `.bound_form_instance` attribute on the - view, which may be used by some renderers.""" + view, which may be used by some renderers. + """ + form = None - - def validate(self, data, files): + def validate_request(self, data, files): """ Given some content as input return some cleaned, validated content. Raises a ErrorResponse with status code 400 (Bad Request) on failure. @@ -434,10 +218,12 @@ class FormResource(Resource): detail[u'errors'] = bound_form.non_field_errors() # Add standard field errors - field_errors = dict((key, map(unicode, val)) + field_errors = dict( + (key, map(unicode, val)) for (key, val) in bound_form.errors.iteritems() - if not key.startswith('__')) + if not key.startswith('__') + ) # Add any unknown field errors for key in unknown_fields: @@ -451,22 +237,24 @@ class FormResource(Resource): def get_bound_form(self, data=None, files=None): - """Given some content return a Django form bound to that content. - If form validation is turned off (form class attribute is None) then returns None.""" - form_cls = getattr(self, 'form', None) - - if not form_cls: + """ + Given some content return a Django form bound to that content. + If form validation is turned off (form class attribute is None) then returns None. + """ + if not self.form: return None if data is not None: - return form_cls(data, files) + return self.form(data, files) - return form_cls() + return self.form() class ModelResource(FormResource): - """Validator class that uses forms for validation and otherwise falls back to a model form if no form is set. - Also provides a get_bound_form() method which may be used by some renderers.""" + """ + Resource class that uses forms for validation and otherwise falls back to a model form if no form is set. + Also provides a get_bound_form() method which may be used by some renderers. + """ """The form class that should be used for validation, or None to use model form validation.""" form = None @@ -477,16 +265,16 @@ class ModelResource(FormResource): """The list of fields we expect to receive as input. Fields in this list will may be received with raising non-existent field errors, even if they do not exist as fields on the ModelForm. - Setting the fields class attribute causes the exclude_fields class attribute to be disregarded.""" + Setting the fields class attribute causes the exclude class attribute to be disregarded.""" fields = None """The list of fields to exclude from the Model. This is only used if the fields class attribute is not set.""" - exclude_fields = ('id', 'pk') + exclude = ('id', 'pk') # TODO: test the different validation here to allow for get get_absolute_url to be supplied on input and not bork out # TODO: be really strict on fields - check they match in the handler methods. (this isn't a validator thing tho.) - def validate(self, data, files): + def validate_request(self, data, files): """ Given some content as input return some cleaned, validated content. Raises a ErrorResponse with status code 400 (Bad Request) on failure. @@ -503,66 +291,80 @@ class ModelResource(FormResource): return self._validate(data, files, allowed_extra_fields=self._property_fields_set) - def get_bound_form(self, data=None, files=None): + def get_bound_form(self, content=None): """Given some content return a Django form bound to that content. If the form class attribute has been explicitly set then use that class to create a Form, otherwise if model is set use that class to create a ModelForm, otherwise return None.""" - form_cls = getattr(self, 'form', None) - model_cls = getattr(self, 'model', None) - - if form_cls: + if self.form: # Use explict Form return super(ModelFormValidator, self).get_bound_form(data, files) - elif model_cls: + elif self.model: # Fall back to ModelForm which we create on the fly class OnTheFlyModelForm(forms.ModelForm): class Meta: - model = model_cls + model = self.model #fields = tuple(self._model_fields_set) # Instantiate the ModelForm as appropriate if content and isinstance(content, models.Model): # Bound to an existing model instance return OnTheFlyModelForm(instance=content) - elif not data is None: - return OnTheFlyModelForm(data, files) + elif content is not None: + return OnTheFlyModelForm(content) return OnTheFlyModelForm() # Both form and model not set? Okay bruv, whatevs... return None + def url(self, instance): + """ + Attempts to reverse resolve the url of the given model instance for this resource. + """ + + # dis does teh magicks... + urlconf = get_urlconf() + resolver = get_resolver(urlconf) + + possibilities = resolver.reverse_dict.getlist(self.view_callable[0]) + for tuple_item in possibilities: + possibility = tuple_item[0] + # pattern = tuple_item[1] + # Note: defaults = tuple_item[2] for django >= 1.3 + for result, params in possibility: + instance_attrs = dict([ (param, getattr(instance, param)) for param in params if hasattr(instance, param) ]) + try: + return reverse(self.view_callable[0], kwargs=instance_attrs) + except NoReverseMatch: + pass + raise NoReverseMatch + + @property def _model_fields_set(self): - """Return a set containing the names of validated fields on the model.""" - resource = self.view.resource - model = getattr(resource, 'model', None) - fields = getattr(resource, 'fields', self.fields) - exclude_fields = getattr(resource, 'exclude_fields', self.exclude_fields) - - model_fields = set(field.name for field in model._meta.fields) + """ + Return a set containing the names of validated fields on the model. + """ + model_fields = set(field.name for field in self.model._meta.fields) if fields: - return model_fields & set(as_tuple(fields)) + return model_fields & set(as_tuple(self.fields)) - return model_fields - set(as_tuple(exclude_fields)) + return model_fields - set(as_tuple(self.exclude)) @property def _property_fields_set(self): - """Returns a set containing the names of validated properties on the model.""" - resource = self.view.resource - model = getattr(resource, 'model', None) - fields = getattr(resource, 'fields', self.fields) - exclude_fields = getattr(resource, 'exclude_fields', self.exclude_fields) - - property_fields = set(attr for attr in dir(model) if - isinstance(getattr(model, attr, None), property) + """ + Returns a set containing the names of validated properties on the model. + """ + property_fields = set(attr for attr in dir(self.model) if + isinstance(getattr(self.model, attr, None), property) and not attr.startswith('_')) if fields: - return property_fields & set(as_tuple(fields)) + return property_fields & set(as_tuple(self.fields)) - return property_fields - set(as_tuple(exclude_fields)) + return property_fields - set(as_tuple(self.exclude)) diff --git a/djangorestframework/tests/resources.py b/djangorestframework/tests/resources.py index 6aa569d34..fd1226be3 100644 --- a/djangorestframework/tests/resources.py +++ b/djangorestframework/tests/resources.py @@ -11,7 +11,7 @@ class TestObjectToData(TestCase): def test_decimal(self): """Decimals need to be converted to a string representation.""" self.assertEquals(_object_to_data(decimal.Decimal('1.5')), '1.5') - + def test_function(self): """Functions with no arguments should be called.""" def foo(): diff --git a/djangorestframework/utils/mediatypes.py b/djangorestframework/utils/mediatypes.py index 62a5e6f36..190cdc2df 100644 --- a/djangorestframework/utils/mediatypes.py +++ b/djangorestframework/utils/mediatypes.py @@ -43,6 +43,7 @@ def add_media_type_param(media_type, key, val): media_type.params[key] = val return str(media_type) + def get_media_type_params(media_type): """ Return a dictionary of the parameters on the given media type. diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 315c25a93..2e7e8418a 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -18,8 +18,10 @@ __all__ = ( class BaseView(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, View): - """Handles incoming requests and maps them to REST operations. - Performs request deserialization, response serialization, authentication and input validation.""" + """ + Handles incoming requests and maps them to REST operations. + Performs request deserialization, response serialization, authentication and input validation. + """ # Use the base resource by default resource = resources.Resource @@ -77,8 +79,8 @@ class BaseView(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, View): prefix = '%s://%s' % (request.is_secure() and 'https' or 'http', request.get_host()) set_script_prefix(prefix) - try: - # Authenticate and check request is has the relevant permissions + try: + # Authenticate and check request has the relevant permissions self._check_permissions() # Get the appropriate handler method @@ -98,7 +100,7 @@ class BaseView(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, View): response = Response(status.HTTP_204_NO_CONTENT) # Pre-serialize filtering (eg filter complex objects into natively serializable types) - response.cleaned_content = self.object_to_data(response.raw_content) + response.cleaned_content = self.filter_response(response.raw_content) except ErrorResponse, exc: response = exc.response @@ -118,7 +120,7 @@ class ModelView(BaseView): """A RESTful view that maps to a model in the database.""" resource = resources.ModelResource -class InstanceModelView(ReadModelMixin, UpdateModelMixin, DeleteModelMixin, ModelView): +class InstanceModelView(InstanceMixin, ReadModelMixin, UpdateModelMixin, DeleteModelMixin, ModelView): """A view which provides default operations for read/update/delete against a model instance.""" pass