Merge branch 'experimental' of https://github.com/sebpiq/django-rest-framework into experimental

This commit is contained in:
Tom Christie 2012-01-05 18:02:52 +00:00
commit e341d0d381
10 changed files with 438 additions and 360 deletions

View File

@ -8,7 +8,6 @@ from django.core.paginator import Paginator
from django.http import HttpResponse from django.http import HttpResponse
from djangorestframework import status from djangorestframework import status
from djangorestframework.renderers import BaseRenderer
from djangorestframework.resources import Resource, FormResource, ModelResource from djangorestframework.resources import Resource, FormResource, ModelResource
from djangorestframework.response import Response, ErrorResponse from djangorestframework.response import Response, ErrorResponse
from djangorestframework.utils import MSIE_USER_AGENT_REGEX from djangorestframework.utils import MSIE_USER_AGENT_REGEX
@ -26,7 +25,11 @@ __all__ = (
# Reverse URL lookup behavior # Reverse URL lookup behavior
'InstanceMixin', 'InstanceMixin',
# Model behavior mixins # Model behavior mixins
'ModelMixin', 'GetResourceMixin',
'PostResourceMixin',
'PutResourceMixin',
'DeleteResourceMixin',
'ListResourceMixin',
) )
@ -405,7 +408,7 @@ class ResourceMixin(object):
and filters the object representation into a serializable object for the and filters the object representation into a serializable object for the
response. response.
""" """
resource = None resource_class = None
@property @property
def CONTENT(self): def CONTENT(self):
@ -429,17 +432,24 @@ class ResourceMixin(object):
""" """
return self.validate_request(self.request.GET) return self.validate_request(self.request.GET)
@property def get_resource_class(self):
def _resource(self): if self.resource_class:
if self.resource: return self.resource_class
return self.resource(self)
elif getattr(self, 'model', None): elif getattr(self, 'model', None):
return ModelResource(self) return ModelResource
elif getattr(self, 'form', None): elif getattr(self, 'form', None):
return FormResource(self) return FormResource
elif getattr(self, '%s_form' % self.method.lower(), None): elif hasattr(self, 'request') and getattr(self, '%s_form' % self.method.lower(), None):
return FormResource(self) return FormResource
return Resource(self) else:
return Resource
@property
def resource(self):
if not hasattr(self, '_resource'):
resource_class = self.get_resource_class()
self._resource = resource_class(view=self)
return self._resource
def validate_request(self, data, files=None): def validate_request(self, data, files=None):
""" """
@ -448,17 +458,17 @@ class ResourceMixin(object):
May raise an :class:`response.ErrorResponse` with status code 400 May raise an :class:`response.ErrorResponse` with status code 400
(Bad Request) on failure. (Bad Request) on failure.
""" """
return self._resource.validate_request(data, files) return self.resource.validate_request(data, files)
def filter_response(self, obj): def filter_response(self, obj):
""" """
Given the response content, filter it into a serializable object. Given the response content, filter it into a serializable object.
""" """
return self._resource.filter_response(obj) return self.resource.filter_response(obj)
def get_bound_form(self, content=None, method=None): def get_bound_form(self, content=None, method=None):
if hasattr(self._resource, 'get_bound_form'): if hasattr(self.resource, 'get_bound_form'):
return self._resource.get_bound_form(content, method=method) return self.resource.get_bound_form(content, method=method)
else: else:
return None return None
@ -479,161 +489,67 @@ class InstanceMixin(object):
associated with this view. associated with this view.
""" """
view = super(InstanceMixin, cls).as_view(**initkwargs) view = super(InstanceMixin, cls).as_view(**initkwargs)
resource = getattr(cls(**initkwargs), 'resource', None) resource_class = getattr(cls(**initkwargs), 'resource_class', None)
if resource: if resource_class:
# We do a little dance when we store the view callable... # We do a little dance when we store the view callable...
# we need to store it wrapped in a 1-tuple, so that inspect will # we need to store it wrapped in a 1-tuple, so that inspect will
# treat it as a function when we later look it up (rather than # treat it as a function when we later look it up (rather than
# turning it into a method). # turning it into a method).
# This makes sure our URL reversing works ok. # This makes sure our URL reversing works ok.
resource.view_callable = (view,) resource_class.view_callable = (view,)
return view return view
########## Model Mixins ########## ########## Resource operation Mixins ##########
class GetResourceMixin(object):
class ModelMixin(object): def get(self, request, *args, **kwargs):
def get_model(self):
"""
Return the model class for this view.
"""
return getattr(self, 'model', self.resource.model)
def get_queryset(self):
"""
Return the queryset that should be used when retrieving or listing
instances.
"""
return getattr(self, 'queryset',
getattr(self.resource, 'queryset',
self.get_model().objects.all()))
def get_ordering(self):
"""
Return the ordering that should be used when listing instances.
"""
return getattr(self, 'ordering',
getattr(self.resource, 'ordering',
None))
# Underlying instance API...
def get_instance(self, *args, **kwargs):
"""
Return a model instance or None.
"""
model = self.get_model()
queryset = self.get_queryset()
try: try:
return queryset.get(**kwargs) self.resource.retrieve(request, *args, **kwargs)
except model.DoesNotExist: except self.resource.DoesNotExist:
return None raise ErrorResponse(status.HTTP_404_NOT_FOUND)
return self.resource.instance
def create_instance(self, *args, **kwargs):
model = self.get_model()
m2m_data = {} class PostResourceMixin(object):
for field in model._meta.many_to_many:
if field.name in kwargs:
m2m_data[field.name] = (
field.m2m_reverse_field_name(), kwargs[field.name]
)
del kwargs[field.name]
instance = model(**kwargs) def post(self, request, *args, **kwargs):
instance.save() self.resource.create(request, *args, **kwargs)
self.resource.update(self.CONTENT, request, *args, **kwargs)
headers = {'Location': self.resource.get_url()}
return Response(status.HTTP_201_CREATED, self.resource.instance, headers)
for fieldname in m2m_data:
manager = getattr(instance, fieldname)
if hasattr(manager, 'add'): class PutResourceMixin(object):
manager.add(*m2m_data[fieldname][1])
else:
data = {}
data[manager.source_field_name] = instance
for related_item in m2m_data[fieldname][1]:
data[m2m_data[fieldname][0]] = related_item
manager.through(**data).save()
return instance
def update_instance(self, instance, *args, **kwargs):
for (key, val) in kwargs.items():
setattr(instance, key, val)
instance.save()
return instance
def delete_instance(self, instance, *args, **kwargs):
instance.delete()
return instance
def list_instances(self, *args, **kwargs):
queryset = self.get_queryset()
ordering = self.get_ordering()
if ordering:
queryset = queryset.order_by(ordering)
return queryset.filter(**kwargs)
# Request/Response layer...
def _get_url_kwargs(self, kwargs):
format_arg = BaseRenderer._FORMAT_QUERY_PARAM
if format_arg in kwargs:
kwargs = kwargs.copy()
del kwargs[format_arg]
return kwargs
def _get_content_kwargs(self, kwargs):
return dict(self._get_url_kwargs(kwargs).items() +
self.CONTENT.items())
def read(self, request, *args, **kwargs):
kwargs = self._get_url_kwargs(kwargs)
instance = self.get_instance(**kwargs)
if instance is None:
raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {})
return instance
def update(self, request, *args, **kwargs):
kwargs = self._get_url_kwargs(kwargs)
instance = self.get_instance(**kwargs)
kwargs = self._get_content_kwargs(kwargs)
if instance:
instance = self.update_instance(instance, **kwargs)
else:
instance = self.create_instance(**kwargs)
return instance
def create(self, request, *args, **kwargs):
kwargs = self._get_content_kwargs(kwargs)
instance = self.create_instance(**kwargs)
def put(self, request, *args, **kwargs):
headers = {} headers = {}
try: try:
headers['Location'] = self.resource(self).url(instance) self.resource.retrieve(request, *args, **kwargs)
except: # TODO: _SkipField should not really happen. status_code = status.HTTP_204_NO_CONTENT
pass except self.resource.DoesNotExist:
self.resource.create(request, *args, **kwargs)
status_code = status.HTTP_201_CREATED
self.resource.update(self.CONTENT, request, *args, **kwargs)
return Response(status_code, self.resource.instance, {})
return Response(status.HTTP_201_CREATED, instance, headers)
def destroy(self, request, *args, **kwargs): class DeleteResourceMixin(object):
kwargs = self._get_url_kwargs(kwargs)
instance = self.delete_instance(**kwargs)
if not instance:
raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {})
return instance def delete(self, request, *args, **kwargs):
try:
self.resource.retrieve(request, *args, **kwargs)
except self.resource.DoesNotExist:
raise ErrorResponse(status.HTTP_404_NOT_FOUND)
self.resource.delete(request, *args, **kwargs)
return
def list(self, request, *args, **kwargs):
return self.list_instances(**kwargs) class ListResourceMixin(object):
def get(self, request, *args, **kwargs):
return self.resource.list(request, *args, **kwargs)
########## Pagination Mixins ########## ########## Pagination Mixins ##########
@ -705,7 +621,7 @@ class PaginatorMixin(object):
# We don't want to paginate responses for anything other than GET # We don't want to paginate responses for anything other than GET
# requests # requests
if self.method.upper() != 'GET': if self.method.upper() != 'GET':
return self._resource.filter_response(obj) return self.resource.filter_response(obj)
paginator = Paginator(obj, self.get_limit()) paginator = Paginator(obj, self.get_limit())
@ -721,7 +637,7 @@ class PaginatorMixin(object):
page = paginator.page(page_num) page = paginator.page(page_num)
serialized_object_list = self._resource.filter_response(page.object_list) serialized_object_list = self.resource.filter_response(page.object_list)
serialized_page_info = self.serialize_page_info(page) serialized_page_info = self.serialize_page_info(page)
serialized_page_info['results'] = serialized_object_list serialized_page_info['results'] = serialized_object_list

View File

@ -6,6 +6,14 @@ from djangorestframework.response import ErrorResponse
from djangorestframework.serializer import Serializer, _SkipField from djangorestframework.serializer import Serializer, _SkipField
def bound_resource_required(meth):
def _decorated(self, *args, **kwargs):
if not self.is_bound():
raise Exception("resource needs to be bound") #TODO: what exception?
return meth(self, *args, **kwargs)
return _decorated
class BaseResource(Serializer): class BaseResource(Serializer):
""" """
Base class for all Resource classes, which simply defines the interface Base class for all Resource classes, which simply defines the interface
@ -15,9 +23,13 @@ class BaseResource(Serializer):
include = () include = ()
exclude = () exclude = ()
def __init__(self, view=None, depth=None, stack=[], **kwargs): # TODO: Inheritance, like for models
class DoesNotExist(Exception): pass
def __init__(self, instance=None, view=None, depth=None, stack=[], **kwargs):
super(BaseResource, self).__init__(depth, stack, **kwargs) super(BaseResource, self).__init__(depth, stack, **kwargs)
self.view = view self.view = view
self.instance = instance
def validate_request(self, data, files=None): def validate_request(self, data, files=None):
""" """
@ -33,6 +45,27 @@ class BaseResource(Serializer):
""" """
return self.serialize(obj) return self.serialize(obj)
def retrieve(self, request, *args, **kwargs):
raise NotImplementedError()
def create(self, request, *args, **kwargs):
raise NotImplementedError()
@bound_resource_required
def update(self, data, request, *args, **kwargs):
raise NotImplementedError()
@bound_resource_required
def delete(self, request, *args, **kwargs):
raise NotImplementedError()
@bound_resource_required
def get_url(self):
raise NotImplementedError()
def is_bound(self):
return not self.instance is None
class Resource(BaseResource): class Resource(BaseResource):
""" """
@ -202,7 +235,6 @@ class FormResource(Resource):
return form return form
def get_bound_form(self, data=None, files=None, method=None): def get_bound_form(self, data=None, files=None, method=None):
""" """
Given some content return a Django form bound to that content. Given some content return a Django form bound to that content.
@ -288,16 +320,136 @@ class ModelResource(FormResource):
is not set. is not set.
""" """
def __init__(self, view=None, depth=None, stack=[], **kwargs): def __init__(self, instance=None, view=None, depth=None, stack=[], **kwargs):
""" """
Allow :attr:`form` and :attr:`model` attributes set on the Allow :attr:`form` and :attr:`model` attributes set on the
:class:`View` to override the :attr:`form` and :attr:`model` :class:`View` to override the :attr:`form` and :attr:`model`
attributes set on the :class:`Resource`. attributes set on the :class:`Resource`.
""" """
super(ModelResource, self).__init__(view, depth, stack, **kwargs) super(ModelResource, self).__init__(instance=instance, view=view, depth=depth, stack=stack, **kwargs)
self.model = getattr(view, 'model', None) or self.model self.model = getattr(view, 'model', None) or self.model
def retrieve(self, request, *args, **kwargs):
"""
Return a model instance or None.
"""
model = self.get_model()
queryset = self.get_queryset()
kwargs = self._clean_url_kwargs(kwargs)
try:
instance = queryset.get(**kwargs)
except model.DoesNotExist:
raise self.DoesNotExist
self.instance = instance
return self.instance
def create(self, request, *args, **kwargs):
model = self.get_model()
kwargs = self._clean_url_kwargs(kwargs)
self.instance = model(**kwargs)
self.instance.save()
return self.instance
@bound_resource_required
def update(self, data, request, *args, **kwargs):
model = self.get_model()
kwargs = self._clean_url_kwargs(kwargs)
data = dict(data, **kwargs)
# Updating many to many relationships
# TODO: code very hard to understand
m2m_data = {}
for field in model._meta.many_to_many:
if field.name in data:
m2m_data[field.name] = (
field.m2m_reverse_field_name(), data[field.name]
)
del data[field.name]
for fieldname in m2m_data:
manager = getattr(self.instance, fieldname)
if hasattr(manager, 'add'):
manager.add(*m2m_data[fieldname][1])
else:
rdata = {}
rdata[manager.source_field_name] = self.instance
for related_item in m2m_data[fieldname][1]:
rdata[m2m_data[fieldname][0]] = related_item
manager.through(**rdata).save()
# Updating other fields
for (key, val) in data.items():
setattr(self.instance, key, val)
self.instance.save()
return self.instance
@bound_resource_required
def delete(self, request, *args, **kwargs):
self.instance.delete()
return self.instance
def list(self, request, *args, **kwargs):
# TODO: QuerysetResource instead !?
kwargs = self._clean_url_kwargs(kwargs)
queryset = self.get_queryset()
ordering = self.get_ordering()
if ordering:
queryset = queryset.order_by(ordering)
return queryset.filter(**kwargs)
@bound_resource_required
def get_url(self):
"""
Attempts to reverse resolve the url of the given model *instance* for
this resource.
Requires a ``View`` with :class:`mixins.InstanceMixin` to have been
created for this resource.
This method can be overridden if you need to set the resource url
reversing explicitly.
"""
if not hasattr(self, 'view_callable'):
raise _SkipField
# dis does teh magicks...
urlconf = get_urlconf()
resolver = get_resolver(urlconf)
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
for tuple_item in possibilities:
possibility = tuple_item[0]
# pattern = tuple_item[1]
# Note: defaults = tuple_item[2] for django >= 1.3
for result, params in possibility:
# instance_attrs = dict([ (param, getattr(instance, param))
# for param in params
# if hasattr(instance, param) ])
instance_attrs = {}
for param in params:
if not hasattr(self.instance, param):
continue
attr = getattr(self.instance, param)
if isinstance(attr, models.Model):
instance_attrs[param] = attr.pk
else:
instance_attrs[param] = attr
try:
return reverse(self.view_callable[0], kwargs=instance_attrs)
except NoReverseMatch:
pass
raise _SkipField
def validate_request(self, data, files=None): def validate_request(self, data, files=None):
""" """
Given some content as input return some cleaned, validated content. Given some content as input return some cleaned, validated content.
@ -318,7 +470,7 @@ class ModelResource(FormResource):
`{field name as string: list of errors as strings}`. `{field name as string: list of errors as strings}`.
""" """
return self._validate(data, files, return self._validate(data, files,
allowed_extra_fields=self._property_fields_set) allowed_extra_fields=self._property_fields_set())
def get_bound_form(self, data=None, files=None, method=None): def get_bound_form(self, data=None, files=None, method=None):
""" """
@ -354,52 +506,6 @@ class ModelResource(FormResource):
return form() return form()
def url(self, instance):
"""
Attempts to reverse resolve the url of the given model *instance* for
this resource.
Requires a ``View`` with :class:`mixins.InstanceMixin` to have been
created for this resource.
This method can be overridden if you need to set the resource url
reversing explicitly.
"""
if not hasattr(self, 'view_callable'):
raise _SkipField
# dis does teh magicks...
urlconf = get_urlconf()
resolver = get_resolver(urlconf)
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
for tuple_item in possibilities:
possibility = tuple_item[0]
# pattern = tuple_item[1]
# Note: defaults = tuple_item[2] for django >= 1.3
for result, params in possibility:
# instance_attrs = dict([ (param, getattr(instance, param))
# for param in params
# if hasattr(instance, param) ])
instance_attrs = {}
for param in params:
if not hasattr(instance, param):
continue
attr = getattr(instance, param)
if isinstance(attr, models.Model):
instance_attrs[param] = attr.pk
else:
instance_attrs[param] = attr
try:
return reverse(self.view_callable[0], kwargs=instance_attrs)
except NoReverseMatch:
pass
raise _SkipField
@property @property
def _model_fields_set(self): def _model_fields_set(self):
""" """
@ -412,8 +518,6 @@ class ModelResource(FormResource):
return model_fields - set(as_tuple(self.exclude)) return model_fields - set(as_tuple(self.exclude))
@property
def _property_fields_set(self): def _property_fields_set(self):
""" """
Returns a set containing the names of validated properties on the model. Returns a set containing the names of validated properties on the model.
@ -426,3 +530,35 @@ class ModelResource(FormResource):
return property_fields & set(self.fields) return property_fields & set(self.fields)
return property_fields.union(set(self.include)) - set(self.exclude) return property_fields.union(set(self.include)) - set(self.exclude)
def get_model(self):
"""
Return the model class for this view.
"""
return getattr(self, 'model', getattr(self.view, 'model', None))
def get_queryset(self):
"""
Return the queryset that should be used when retrieving or listing
instances.
"""
return getattr(self, 'queryset',
getattr(self.view, 'queryset',
self.get_model().objects.all()))
def get_ordering(self):
"""
Return the ordering that should be used when listing instances.
"""
return getattr(self, 'ordering',
getattr(self.view, 'ordering',
None))
def _clean_url_kwargs(self, kwargs):
# TODO: probably this functionality shouldn't be there
from djangorestframework.renderers import BaseRenderer
format_arg = BaseRenderer._FORMAT_QUERY_PARAM
if format_arg in kwargs:
kwargs = kwargs.copy()
del kwargs[format_arg]
return kwargs

View File

@ -4,120 +4,12 @@ from django.utils import simplejson as json
from djangorestframework import status from djangorestframework import status
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from djangorestframework.mixins import PaginatorMixin, ModelMixin from djangorestframework.mixins import PaginatorMixin
from djangorestframework.resources import ModelResource
from djangorestframework.response import Response from djangorestframework.response import Response
from djangorestframework.tests.models import CustomUser
from djangorestframework.tests.testcases import TestModelsTestCase from djangorestframework.tests.testcases import TestModelsTestCase
from djangorestframework.views import View from djangorestframework.views import View
class TestModelCreation(TestModelsTestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
super(TestModelsTestCase, self).setUp()
self.req = RequestFactory()
def test_creation(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
form_data = {'name': 'foo'}
request = self.req.post('/groups', data=form_data)
mixin = ModelMixin()
mixin.resource = GroupResource
mixin.CONTENT = form_data
response = mixin.create(request)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', response.cleaned_content.name)
def test_creation_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
form_data = {
'username': 'bar',
'password': 'baz',
'groups': [group.id]
}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(1, User.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
def test_creation_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
def url(self, instance):
return "/customusers/%i" % instance.id
form_data = {'username': 'bar0', 'groups': []}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = []
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(1, CustomUser.objects.count())
self.assertEquals(0, response.cleaned_content.groups.count())
group = Group(name='foo1')
group.save()
form_data = {'username': 'bar1', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(2, CustomUser.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
group2 = Group(name='foo2')
group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group, group2]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(3, CustomUser.objects.count())
self.assertEquals(2, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
class MockPaginatorView(PaginatorMixin, View): class MockPaginatorView(PaginatorMixin, View):
total = 60 total = 60

View File

@ -23,12 +23,12 @@ class CustomUserResource(ModelResource):
model = CustomUser model = CustomUser
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'), url(r'^users/$', ListOrCreateModelView.as_view(resource_class=UserResource), name='users'),
url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)), url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource_class=UserResource)),
url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'), url(r'^customusers/$', ListOrCreateModelView.as_view(resource_class=CustomUserResource), name='customusers'),
url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)), url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource_class=CustomUserResource)),
url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'), url(r'^groups/$', ListOrCreateModelView.as_view(resource_class=GroupResource), name='groups'),
url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)), url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource_class=GroupResource)),
) )

View File

@ -0,0 +1,132 @@
from djangorestframework.views import View
from djangorestframework.resources import ModelResource
from djangorestframework.tests.testcases import TestModelsTestCase
from djangorestframework.compat import RequestFactory
from djangorestframework.tests.models import CustomUser
from django.contrib.auth.models import Group, User
class MockView(View):
"""This is a basic mock view"""
pass
class TestModelCreation(TestModelsTestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
super(TestModelsTestCase, self).setUp()
self.req = RequestFactory()
def test_create(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
request = self.req.post('/groups', data={})
args = []
kwargs = {'name': 'foo'}
resource = GroupResource(view=MockView.as_view())
resource.create(request, *args, **kwargs)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', resource.instance.name)
def test_update(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
group = Group(name='foo')
group.save()
request = self.req.post('/groups', data={})
args = []
kwargs = {}
data = {'name': 'bla'}
resource = GroupResource(instance=group, view=MockView.as_view())
resource.update(data, request, *args, **kwargs)
self.assertEquals('bla', resource.instance.name)
def test_update_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
user = User(username='bar')
user.save()
form_data = {
'username': 'bar',
'password': 'baz',
'groups': [group.id]
}
request = self.req.post('/groups', data=form_data)
args = []
kwargs = {}
cleaned_data = dict(form_data, groups=[group])
resource = UserResource(instance=user, view=MockView.as_view())
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(1, resource.instance.groups.count())
self.assertEquals('foo', resource.instance.groups.all()[0].name)
def test_update_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
def url(self, instance):
return "/customusers/%i" % instance.id
user = User(username='bar')
user.save()
form_data = {'groups': []}
request = self.req.post('/groups', data=form_data)
args = []
kwargs = {}
cleaned_data = dict(form_data, groups=[])
resource = UserResource(instance=user, view=MockView.as_view())
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(0, resource.instance.groups.count())
group = Group(name='foo1')
group.save()
form_data = {'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data, groups=[group])
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(1, resource.instance.groups.count())
self.assertEquals('foo1', resource.instance.groups.all()[0].name)
group2 = Group(name='foo2')
group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data, groups=[group, group2])
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(2, resource.instance.groups.count())
self.assertEquals('foo1', resource.instance.groups.all()[0].name)
self.assertEquals('foo2', resource.instance.groups.all()[1].name)

View File

@ -23,7 +23,7 @@ class MockView_PerViewThrottling(MockView):
class MockView_PerResourceThrottling(MockView): class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, ) permissions = ( PerResourceThrottling, )
resource = FormResource resource_class = FormResource
class MockView_MinuteThrottling(MockView): class MockView_MinuteThrottling(MockView):
throttle = '3/min' throttle = '3/min'

View File

@ -83,7 +83,7 @@ class TestNonFieldErrors(TestCase):
view = MockView() view = MockView()
content = {'field1': 'example1', 'field2': 'example2'} content = {'field1': 'example1', 'field2': 'example2'}
try: try:
MockResource(view).validate_request(content, None) MockResource(view=view).validate_request(content, None)
except ErrorResponse, exc: except ErrorResponse, exc:
self.assertEqual(exc.response.raw_content, {'errors': [MockForm.ERROR_TEXT]}) self.assertEqual(exc.response.raw_content, {'errors': [MockForm.ERROR_TEXT]})
else: else:
@ -187,77 +187,77 @@ class TestFormValidation(TestCase):
# Tests on FormResource # Tests on FormResource
def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self): def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
def test_form_validation_failure_raises_response_exception(self): def test_form_validation_failure_raises_response_exception(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failure_raises_response_exception(validator) self.validation_failure_raises_response_exception(validator)
def test_validation_does_not_allow_extra_fields_by_default(self): def test_validation_does_not_allow_extra_fields_by_default(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_does_not_allow_extra_fields_by_default(validator) self.validation_does_not_allow_extra_fields_by_default(validator)
def test_validation_allows_extra_fields_if_explicitly_set(self): def test_validation_allows_extra_fields_if_explicitly_set(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_allows_extra_fields_if_explicitly_set(validator) self.validation_allows_extra_fields_if_explicitly_set(validator)
def test_validation_does_not_require_extra_fields_if_explicitly_set(self): def test_validation_does_not_require_extra_fields_if_explicitly_set(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_does_not_require_extra_fields_if_explicitly_set(validator) self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
def test_validation_failed_due_to_no_content_returns_appropriate_message(self): def test_validation_failed_due_to_no_content_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_no_content_returns_appropriate_message(validator) self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
def test_validation_failed_due_to_field_error_returns_appropriate_message(self): def test_validation_failed_due_to_field_error_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_field_error_returns_appropriate_message(validator) self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self): def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
# Same tests on ModelResource # Same tests on ModelResource
def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self): def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
def test_modelform_validation_failure_raises_response_exception(self): def test_modelform_validation_failure_raises_response_exception(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failure_raises_response_exception(validator) self.validation_failure_raises_response_exception(validator)
def test_modelform_validation_does_not_allow_extra_fields_by_default(self): def test_modelform_validation_does_not_allow_extra_fields_by_default(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_does_not_allow_extra_fields_by_default(validator) self.validation_does_not_allow_extra_fields_by_default(validator)
def test_modelform_validation_allows_extra_fields_if_explicitly_set(self): def test_modelform_validation_allows_extra_fields_if_explicitly_set(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_allows_extra_fields_if_explicitly_set(validator) self.validation_allows_extra_fields_if_explicitly_set(validator)
def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self): def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_does_not_require_extra_fields_if_explicitly_set(validator) self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_no_content_returns_appropriate_message(validator) self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_field_error_returns_appropriate_message(validator) self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
@ -280,7 +280,7 @@ class TestModelFormValidator(TestCase):
class MockView(View): class MockView(View):
resource = MockResource resource = MockResource
self.validator = MockResource(MockView) self.validator = MockResource(view=MockView)
def test_property_fields_are_allowed_on_model_forms(self): def test_property_fields_are_allowed_on_model_forms(self):

View File

@ -44,8 +44,8 @@ urlpatterns = patterns('djangorestframework.utils.staticviews',
url(r'^accounts/logout$', 'api_logout'), url(r'^accounts/logout$', 'api_logout'),
url(r'^mock/$', MockView.as_view()), url(r'^mock/$', MockView.as_view()),
url(r'^resourcemock/$', ResourceMockView.as_view()), url(r'^resourcemock/$', ResourceMockView.as_view()),
url(r'^model/$', ListOrCreateModelView.as_view(resource=MockResource)), url(r'^model/$', ListOrCreateModelView.as_view(resource_class=MockResource)),
url(r'^model/(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource=MockResource)), url(r'^model/(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource_class=MockResource)),
) )
class BaseViewTests(TestCase): class BaseViewTests(TestCase):

View File

@ -19,9 +19,14 @@ def get_name(view):
if getattr(view, 'cls_instance', None): if getattr(view, 'cls_instance', None):
view = view.cls_instance view = view.cls_instance
# If the view seems to have a resource class, we get it
resource_class = None
if hasattr(view, 'get_resource_class'):
resource_class = view.get_resource_class()
# If this view has a resource that's been overridden, then use that resource for the name # If this view has a resource that's been overridden, then use that resource for the name
if getattr(view, 'resource', None) not in (None, Resource, FormResource, ModelResource): if resource_class not in (None, Resource, FormResource, ModelResource):
name = view.resource.__name__ name = resource_class.__name__
# Chomp of any non-descriptive trailing part of the resource class name # Chomp of any non-descriptive trailing part of the resource class name
if name.endswith('Resource') and name != 'Resource': if name.endswith('Resource') and name != 'Resource':
@ -63,10 +68,14 @@ def get_description(view):
if getattr(view, 'cls_instance', None): if getattr(view, 'cls_instance', None):
view = view.cls_instance view = view.cls_instance
# If the view seems to have a resource class, we get it
resource_class = None
if hasattr(view, 'get_resource_class'):
resource_class = view.get_resource_class()
# If this view has a resource that's been overridden, then use the resource's doctring # If this view has a resource that's been overridden, then use the resource's doctring
if getattr(view, 'resource', None) not in (None, Resource, FormResource, ModelResource): if resource_class not in (None, Resource, FormResource, ModelResource):
doc = view.resource.__doc__ doc = resource_class.__doc__
# Otherwise use the view doctring # Otherwise use the view doctring
elif getattr(view, '__doc__', None): elif getattr(view, '__doc__', None):

View File

@ -35,7 +35,7 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
The resource to use when validating requests and filtering responses, The resource to use when validating requests and filtering responses,
or `None` to use default behaviour. or `None` to use default behaviour.
""" """
resource = None resource_class = None
""" """
List of renderers the resource can serialize the response with, ordered by preference. List of renderers the resource can serialize the response with, ordered by preference.
@ -178,14 +178,16 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
return ret return ret
class ModelView(ModelMixin, View): class ModelView(View):
""" """
A RESTful view that maps to a model in the database. A RESTful view that maps to a model in the database.
""" """
resource = resources.ModelResource resource_class = resources.ModelResource
model = None
class InstanceModelView(InstanceMixin, ModelView): class InstanceModelView(GetResourceMixin, PutResourceMixin, DeleteResourceMixin,
InstanceMixin, ModelView):
""" """
A view which provides default operations for read/update/delete against a A view which provides default operations for read/update/delete against a
model instance. This view is also treated as the Canonical identifier model instance. This view is also treated as the Canonical identifier
@ -193,27 +195,18 @@ class InstanceModelView(InstanceMixin, ModelView):
""" """
_suffix = 'Instance' _suffix = 'Instance'
get = ModelMixin.read
put = ModelMixin.update
delete = ModelMixin.destroy
class ListModelView(ListResourceMixin, ModelView):
class ListModelView(ModelView):
""" """
A view which provides default operations for list, against a model in the A view which provides default operations for list, against a model in the
database. database.
""" """
_suffix = 'List' _suffix = 'List'
get = ModelMixin.list
class ListOrCreateModelView(PostResourceMixin, ListResourceMixin, ModelView):
class ListOrCreateModelView(ModelView):
""" """
A view which provides default operations for list and create, against a A view which provides default operations for list and create, against a
model in the database. model in the database.
""" """
_suffix = 'List' _suffix = 'List'
get = ModelMixin.list
post = ModelMixin.create