First draft for rails-like relationship accessing

This commit is contained in:
Pierre Dulac 2013-03-18 18:06:32 +01:00
parent 22a389d0ba
commit e9670b1ac1
3 changed files with 91 additions and 0 deletions

View File

@ -253,3 +253,24 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs) return self.destroy(request, *args, **kwargs)
class RetrieveRelationshipAPIView(mixins.RetrieveRelationshipMixin,
SingleObjectAPIView):
"""
Rails-like relationship access
Eg. /api/album/1/tracks
to access the related field `tracks` of the album with pk=1
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
def get_serializer_class(self):
relationship = self.get_related_field_from_relationship()
class DefaultSerializer(self.model_serializer_class):
class Meta:
model = relationship.model
return DefaultSerializer

View File

@ -10,6 +10,8 @@ from django.http import Http404
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.relations import Relationship
from rest_framework.compat import get_concrete_model
def _get_validation_exclusions(obj, pk=None, slug_field=None): def _get_validation_exclusions(obj, pk=None, slug_field=None):
@ -168,3 +170,37 @@ class DestroyModelMixin(object):
obj = self.get_object() obj = self.get_object()
obj.delete() obj.delete()
return Response(status=status.HTTP_204_NO_CONTENT) return Response(status=status.HTTP_204_NO_CONTENT)
class RetrieveRelationshipMixin(object):
"""
Retrieve relationships
"""
def retrieve(self, request, *args, **kwargs):
relationship = self.get_related_field_from_relationship()
self.object = self.get_object()
self.related_object_or_list = getattr(self.object, relationship.field_name)
if relationship.to_many:
self.related_object_or_list = self.related_object_or_list.all()
serializer = self.get_serializer(self.related_object_or_list, many=relationship.to_many)
return Response(serializer.data)
def get_related_field_from_relationship(self):
"""
Return a `Relationship` instance for the related field specified in the request arguments
"""
relationship = self.kwargs.get('relationship', None)
relation = None
if relationship:
opts = get_concrete_model(self.model)._meta
fields = opts.get_all_field_names()
if relationship in fields:
field, model, direct, m2m = opts.get_field_by_name(relationship)
relation = Relationship(field)
if not relation:
raise Exception("No relationship found for '%s'" % relationship)
return relation

View File

@ -5,6 +5,8 @@ from django import forms
from django.forms import widgets from django.forms import widgets
from django.forms.models import ModelChoiceIterator from django.forms.models import ModelChoiceIterator
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.db.models.related import RelatedObject
from django.db.models.fields.related import ForeignKey
from rest_framework.fields import Field, WritableField, get_component from rest_framework.fields import Field, WritableField, get_component
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.compat import urlparse from rest_framework.compat import urlparse
@ -478,6 +480,38 @@ class HyperlinkedIdentityField(Field):
raise Exception('Could not resolve URL for field using view name "%s"' % view_name) raise Exception('Could not resolve URL for field using view name "%s"' % view_name)
class Relationship(object):
"""
Wrapper for a relationship
Encapsulate the `RelatedObject` & `ForeignKey` resolution.
"""
_field = None
def __init__(self, field):
if not isinstance(field, RelatedObject) and not isinstance(field, ForeignKey):
raise Exception("Unsupported type of relationship '%s'" % field)
self._field = field
@property
def to_many(self):
if isinstance(self._field, RelatedObject):
return True
return False
@property
def field_name(self):
if isinstance(self._field, RelatedObject):
return self._field.field.related_query_name()
return self._field.name
@property
def model(self):
if isinstance(self._field, RelatedObject):
return self._field.model
return self._field.related.parent_model
### Old-style many classes for backwards compat ### Old-style many classes for backwards compat
class ManyRelatedField(RelatedField): class ManyRelatedField(RelatedField):