diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 36ecf9150..3bf828917 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -253,3 +253,24 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, def delete(self, request, *args, **kwargs): return self.destroy(request, *args, **kwargs) + + +class RetrieveRelationshipAPIView(mixins.RetrieveRelationshipMixin, + SingleObjectAPIView): + """ + Rails-like relationship access + + Eg. /api/album/1/tracks + to access the related field `tracks` of the album with pk=1 + """ + def get(self, request, *args, **kwargs): + return self.retrieve(request, *args, **kwargs) + + def get_serializer_class(self): + relationship = self.get_related_field_from_relationship() + + class DefaultSerializer(self.model_serializer_class): + class Meta: + model = relationship.model + + return DefaultSerializer diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 7d9a6e654..5b63aa6a8 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -10,6 +10,8 @@ from django.http import Http404 from rest_framework import status from rest_framework.response import Response from rest_framework.request import clone_request +from rest_framework.relations import Relationship +from rest_framework.compat import get_concrete_model def _get_validation_exclusions(obj, pk=None, slug_field=None): @@ -168,3 +170,37 @@ class DestroyModelMixin(object): obj = self.get_object() obj.delete() return Response(status=status.HTTP_204_NO_CONTENT) + + +class RetrieveRelationshipMixin(object): + """ + Retrieve relationships + """ + def retrieve(self, request, *args, **kwargs): + relationship = self.get_related_field_from_relationship() + self.object = self.get_object() + + self.related_object_or_list = getattr(self.object, relationship.field_name) + if relationship.to_many: + self.related_object_or_list = self.related_object_or_list.all() + + serializer = self.get_serializer(self.related_object_or_list, many=relationship.to_many) + return Response(serializer.data) + + def get_related_field_from_relationship(self): + """ + Return a `Relationship` instance for the related field specified in the request arguments + """ + relationship = self.kwargs.get('relationship', None) + relation = None + if relationship: + opts = get_concrete_model(self.model)._meta + fields = opts.get_all_field_names() + if relationship in fields: + field, model, direct, m2m = opts.get_field_by_name(relationship) + relation = Relationship(field) + + if not relation: + raise Exception("No relationship found for '%s'" % relationship) + + return relation diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 2a10e9af5..5d296010e 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -5,6 +5,8 @@ from django import forms from django.forms import widgets from django.forms.models import ModelChoiceIterator from django.utils.translation import ugettext_lazy as _ +from django.db.models.related import RelatedObject +from django.db.models.fields.related import ForeignKey from rest_framework.fields import Field, WritableField, get_component from rest_framework.reverse import reverse from rest_framework.compat import urlparse @@ -478,6 +480,38 @@ class HyperlinkedIdentityField(Field): raise Exception('Could not resolve URL for field using view name "%s"' % view_name) +class Relationship(object): + """ + Wrapper for a relationship + + Encapsulate the `RelatedObject` & `ForeignKey` resolution. + """ + _field = None + + def __init__(self, field): + if not isinstance(field, RelatedObject) and not isinstance(field, ForeignKey): + raise Exception("Unsupported type of relationship '%s'" % field) + self._field = field + + @property + def to_many(self): + if isinstance(self._field, RelatedObject): + return True + return False + + @property + def field_name(self): + if isinstance(self._field, RelatedObject): + return self._field.field.related_query_name() + return self._field.name + + @property + def model(self): + if isinstance(self._field, RelatedObject): + return self._field.model + return self._field.related.parent_model + + ### Old-style many classes for backwards compat class ManyRelatedField(RelatedField):