new-style resources + tests

This commit is contained in:
Sébastien Piquemal 2012-01-03 19:17:25 +02:00
parent d987b745f9
commit 926059dad6
10 changed files with 382 additions and 378 deletions

View File

@ -8,7 +8,6 @@ from django.core.paginator import Paginator
from django.http import HttpResponse from django.http import HttpResponse
from djangorestframework import status from djangorestframework import status
from djangorestframework.renderers import BaseRenderer
from djangorestframework.resources import Resource, FormResource, ModelResource from djangorestframework.resources import Resource, FormResource, ModelResource
from djangorestframework.response import Response, ErrorResponse from djangorestframework.response import Response, ErrorResponse
from djangorestframework.utils import MSIE_USER_AGENT_REGEX from djangorestframework.utils import MSIE_USER_AGENT_REGEX
@ -26,7 +25,11 @@ __all__ = (
# Reverse URL lookup behavior # Reverse URL lookup behavior
'InstanceMixin', 'InstanceMixin',
# Model behavior mixins # Model behavior mixins
'ModelMixin', 'GetResourceMixin',
'PostResourceMixin',
'PutResourceMixin',
'DeleteResourceMixin',
'ListResourceMixin',
) )
@ -429,17 +432,24 @@ class ResourceMixin(object):
""" """
return self.validate_request(self.request.GET) return self.validate_request(self.request.GET)
@property def get_resource_class(self):
def _resource(self):
if self.resource_class: if self.resource_class:
return self.resource_class(self) return self.resource_class
elif getattr(self, 'model', None): elif getattr(self, 'model', None):
return ModelResource(self) return ModelResource
elif getattr(self, 'form', None): elif getattr(self, 'form', None):
return FormResource(self) return FormResource
elif getattr(self, '%s_form' % self.method.lower(), None): elif hasattr(self, 'request') and getattr(self, '%s_form' % self.method.lower(), None):
return FormResource(self) return FormResource
return Resource(self) else:
return Resource
@property
def resource(self):
if not hasattr(self, '_resource'):
resource_class = self.get_resource_class()
self._resource = resource_class(view=self)
return self._resource
def validate_request(self, data, files=None): def validate_request(self, data, files=None):
""" """
@ -448,17 +458,17 @@ class ResourceMixin(object):
May raise an :class:`response.ErrorResponse` with status code 400 May raise an :class:`response.ErrorResponse` with status code 400
(Bad Request) on failure. (Bad Request) on failure.
""" """
return self._resource.validate_request(data, files) return self.resource.validate_request(data, files)
def filter_response(self, obj): def filter_response(self, obj):
""" """
Given the response content, filter it into a serializable object. Given the response content, filter it into a serializable object.
""" """
return self._resource.filter_response(obj) return self.resource.filter_response(obj)
def get_bound_form(self, content=None, method=None): def get_bound_form(self, content=None, method=None):
if hasattr(self._resource, 'get_bound_form'): if hasattr(self.resource, 'get_bound_form'):
return self._resource.get_bound_form(content, method=method) return self.resource.get_bound_form(content, method=method)
else: else:
return None return None
@ -479,72 +489,69 @@ class InstanceMixin(object):
associated with this view. associated with this view.
""" """
view = super(InstanceMixin, cls).as_view(**initkwargs) view = super(InstanceMixin, cls).as_view(**initkwargs)
resource = getattr(cls(**initkwargs), 'resource', None) resource_class = getattr(cls(**initkwargs), 'resource_class', None)
if resource: if resource_class:
# We do a little dance when we store the view callable... # We do a little dance when we store the view callable...
# we need to store it wrapped in a 1-tuple, so that inspect will # we need to store it wrapped in a 1-tuple, so that inspect will
# treat it as a function when we later look it up (rather than # treat it as a function when we later look it up (rather than
# turning it into a method). # turning it into a method).
# This makes sure our URL reversing works ok. # This makes sure our URL reversing works ok.
resource.view_callable = (view,) resource_class.view_callable = (view,)
return view return view
########## Resource operation Mixins ########## ########## Resource operation Mixins ##########
class ReadResourceMixin(object): class GetResourceMixin(object):
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
try: try:
resource = self.resource_class.retrieve(request, *args, **kwargs) self.resource.retrieve(request, *args, **kwargs)
except self.resource_class.DoesNotExist: except self.resource.DoesNotExist:
raise ErrorResponse(status.HTTP_404_NOT_FOUND) raise ErrorResponse(status.HTTP_404_NOT_FOUND)
return resource return self.resource.instance
class CreateResourceMixin(object): class PostResourceMixin(object):
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
resource = self.resource_class.create(request, *args, **kwargs) self.resource.create(request, *args, **kwargs)
resource.update(self.CONTENT, request, *args, **kwargs) self.resource.update(self.CONTENT, request, *args, **kwargs)
headers = {'Location': resource.get_url()} headers = {'Location': self.resource.get_url()}
return Response(status.HTTP_201_CREATED, resource, headers) return Response(status.HTTP_201_CREATED, self.resource.instance, headers)
class CreateSubResourceMixin(object): class PutResourceMixin(object):
def post(self, request, *args, **kwargs):
sub_resource = self.resource_class.create(request, *args, **kwargs)
sub_resource.update(self.CONTENT, request, *args, **kwargs)
headers = {'Location': sub_resource.get_url()}
return Response(status.HTTP_201_CREATED, sub_resource, headers)
class UpdateResourceMixin(object):
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
headers = {} headers = {}
try: try:
resource = self.resource_class.retrieve(request, *args, **kwargs) self.resource.retrieve(request, *args, **kwargs)
status_code = status.HTTP_204_NO_CONTENT status_code = status.HTTP_204_NO_CONTENT
except self.resource_class.DoesNotExist: except self.resource.DoesNotExist:
resource = self.resource_class.create(request, *args, **kwargs) self.resource.create(request, *args, **kwargs)
status_code = status.HTTP_201_CREATED status_code = status.HTTP_201_CREATED
resource.update(self.CONTENT, request, *args, **kwargs) self.resource.update(self.CONTENT, request, *args, **kwargs)
return Response(status_code, resource, {}) return Response(status_code, self.resource.instance, {})
class DeleteResourceMixin(object): class DeleteResourceMixin(object):
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
try: try:
resource = self.resource_class.retrieve(request, *args, **kwargs) self.resource.retrieve(request, *args, **kwargs)
except self.resource_class.DoesNotExist: except self.resource.DoesNotExist:
raise ErrorResponse(status.HTTP_404_NOT_FOUND) raise ErrorResponse(status.HTTP_404_NOT_FOUND)
resource.delete(request, *args, **kwargs) self.resource.delete(request, *args, **kwargs)
return return
class ListResourceMixin(object):
def get(self, request, *args, **kwargs):
return self.resource.list(request, *args, **kwargs)
########## Pagination Mixins ########## ########## Pagination Mixins ##########
class PaginatorMixin(object): class PaginatorMixin(object):
@ -614,7 +621,7 @@ class PaginatorMixin(object):
# We don't want to paginate responses for anything other than GET # We don't want to paginate responses for anything other than GET
# requests # requests
if self.method.upper() != 'GET': if self.method.upper() != 'GET':
return self._resource.filter_response(obj) return self.resource.filter_response(obj)
paginator = Paginator(obj, self.get_limit()) paginator = Paginator(obj, self.get_limit())
@ -630,7 +637,7 @@ class PaginatorMixin(object):
page = paginator.page(page_num) page = paginator.page(page_num)
serialized_object_list = self._resource.filter_response(page.object_list) serialized_object_list = self.resource.filter_response(page.object_list)
serialized_page_info = self.serialize_page_info(page) serialized_page_info = self.serialize_page_info(page)
serialized_page_info['results'] = serialized_object_list serialized_page_info['results'] = serialized_object_list

View File

@ -6,6 +6,14 @@ from djangorestframework.response import ErrorResponse
from djangorestframework.serializer import Serializer, _SkipField from djangorestframework.serializer import Serializer, _SkipField
def bound_resource_required(meth):
def _decorated(self, *args, **kwargs):
if not self.is_bound():
raise Exception("resource needs to be bound") #TODO: what exception?
return meth(self, *args, **kwargs)
return _decorated
class BaseResource(Serializer): class BaseResource(Serializer):
""" """
Base class for all Resource classes, which simply defines the interface Base class for all Resource classes, which simply defines the interface
@ -16,11 +24,12 @@ class BaseResource(Serializer):
exclude = () exclude = ()
# TODO: Inheritance, like for models # TODO: Inheritance, like for models
class DoesNotExist(Exception) class DoesNotExist(Exception): pass
def __init__(self, depth=None, stack=[], **kwargs): def __init__(self, instance=None, view=None, depth=None, stack=[], **kwargs):
super(BaseResource, self).__init__(depth, stack, **kwargs) super(BaseResource, self).__init__(depth, stack, **kwargs)
self.view = view self.view = view
self.instance = instance
def validate_request(self, data, files=None): def validate_request(self, data, files=None):
""" """
@ -36,23 +45,27 @@ class BaseResource(Serializer):
""" """
return self.serialize(obj) return self.serialize(obj)
@classmethod def retrieve(self, request, *args, **kwargs):
def retrieve(cls, request, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
@classmethod def create(self, request, *args, **kwargs):
def create(cls, request, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
@bound_resource_required
def update(self, data, request, *args, **kwargs): def update(self, data, request, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
@bound_resource_required
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
@bound_resource_required
def get_url(self): def get_url(self):
raise NotImplementedError() raise NotImplementedError()
def is_bound(self):
return not self.instance is None
class Resource(BaseResource): class Resource(BaseResource):
""" """
@ -222,7 +235,6 @@ class FormResource(Resource):
return form return form
def get_bound_form(self, data=None, files=None, method=None): def get_bound_form(self, data=None, files=None, method=None):
""" """
Given some content return a Django form bound to that content. Given some content return a Django form bound to that content.
@ -308,16 +320,136 @@ class ModelResource(FormResource):
is not set. is not set.
""" """
def __init__(self, view=None, depth=None, stack=[], **kwargs): def __init__(self, instance=None, view=None, depth=None, stack=[], **kwargs):
""" """
Allow :attr:`form` and :attr:`model` attributes set on the Allow :attr:`form` and :attr:`model` attributes set on the
:class:`View` to override the :attr:`form` and :attr:`model` :class:`View` to override the :attr:`form` and :attr:`model`
attributes set on the :class:`Resource`. attributes set on the :class:`Resource`.
""" """
super(ModelResource, self).__init__(view, depth, stack, **kwargs) super(ModelResource, self).__init__(instance=instance, view=view, depth=depth, stack=stack, **kwargs)
self.model = getattr(view, 'model', None) or self.model self.model = getattr(view, 'model', None) or self.model
def retrieve(self, request, *args, **kwargs):
"""
Return a model instance or None.
"""
model = self.get_model()
queryset = self.get_queryset()
kwargs = self._clean_url_kwargs(kwargs)
try:
instance = queryset.get(**kwargs)
except model.DoesNotExist:
raise self.DoesNotExist
self.instance = instance
return self.instance
def create(self, request, *args, **kwargs):
model = self.get_model()
kwargs = self._clean_url_kwargs(kwargs)
self.instance = model(**kwargs)
self.instance.save()
return self.instance
@bound_resource_required
def update(self, data, request, *args, **kwargs):
model = self.get_model()
kwargs = self._clean_url_kwargs(kwargs)
data = dict(data, **kwargs)
# Updating many to many relationships
# TODO: code very hard to understand
m2m_data = {}
for field in model._meta.many_to_many:
if field.name in data:
m2m_data[field.name] = (
field.m2m_reverse_field_name(), data[field.name]
)
del data[field.name]
for fieldname in m2m_data:
manager = getattr(self.instance, fieldname)
if hasattr(manager, 'add'):
manager.add(*m2m_data[fieldname][1])
else:
rdata = {}
rdata[manager.source_field_name] = self.instance
for related_item in m2m_data[fieldname][1]:
rdata[m2m_data[fieldname][0]] = related_item
manager.through(**rdata).save()
# Updating other fields
for (key, val) in data.items():
setattr(self.instance, key, val)
self.instance.save()
return self.instance
@bound_resource_required
def delete(self, request, *args, **kwargs):
self.instance.delete()
return self.instance
def list(self, request, *args, **kwargs):
# TODO: QuerysetResource instead !?
kwargs = self._clean_url_kwargs(kwargs)
queryset = self.get_queryset()
ordering = self.get_ordering()
if ordering:
queryset = queryset.order_by(ordering)
return queryset.filter(**kwargs)
@bound_resource_required
def get_url(self):
"""
Attempts to reverse resolve the url of the given model *instance* for
this resource.
Requires a ``View`` with :class:`mixins.InstanceMixin` to have been
created for this resource.
This method can be overridden if you need to set the resource url
reversing explicitly.
"""
if not hasattr(self, 'view_callable'):
raise _SkipField
# dis does teh magicks...
urlconf = get_urlconf()
resolver = get_resolver(urlconf)
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
for tuple_item in possibilities:
possibility = tuple_item[0]
# pattern = tuple_item[1]
# Note: defaults = tuple_item[2] for django >= 1.3
for result, params in possibility:
# instance_attrs = dict([ (param, getattr(instance, param))
# for param in params
# if hasattr(instance, param) ])
instance_attrs = {}
for param in params:
if not hasattr(self.instance, param):
continue
attr = getattr(self.instance, param)
if isinstance(attr, models.Model):
instance_attrs[param] = attr.pk
else:
instance_attrs[param] = attr
try:
return reverse(self.view_callable[0], kwargs=instance_attrs)
except NoReverseMatch:
pass
raise _SkipField
def validate_request(self, data, files=None): def validate_request(self, data, files=None):
""" """
Given some content as input return some cleaned, validated content. Given some content as input return some cleaned, validated content.
@ -338,7 +470,7 @@ class ModelResource(FormResource):
`{field name as string: list of errors as strings}`. `{field name as string: list of errors as strings}`.
""" """
return self._validate(data, files, return self._validate(data, files,
allowed_extra_fields=self._property_fields_set) allowed_extra_fields=self._property_fields_set())
def get_bound_form(self, data=None, files=None, method=None): def get_bound_form(self, data=None, files=None, method=None):
""" """
@ -374,52 +506,6 @@ class ModelResource(FormResource):
return form() return form()
def url(self, instance):
"""
Attempts to reverse resolve the url of the given model *instance* for
this resource.
Requires a ``View`` with :class:`mixins.InstanceMixin` to have been
created for this resource.
This method can be overridden if you need to set the resource url
reversing explicitly.
"""
if not hasattr(self, 'view_callable'):
raise _SkipField
# dis does teh magicks...
urlconf = get_urlconf()
resolver = get_resolver(urlconf)
possibilities = resolver.reverse_dict.getlist(self.view_callable[0])
for tuple_item in possibilities:
possibility = tuple_item[0]
# pattern = tuple_item[1]
# Note: defaults = tuple_item[2] for django >= 1.3
for result, params in possibility:
# instance_attrs = dict([ (param, getattr(instance, param))
# for param in params
# if hasattr(instance, param) ])
instance_attrs = {}
for param in params:
if not hasattr(instance, param):
continue
attr = getattr(instance, param)
if isinstance(attr, models.Model):
instance_attrs[param] = attr.pk
else:
instance_attrs[param] = attr
try:
return reverse(self.view_callable[0], kwargs=instance_attrs)
except NoReverseMatch:
pass
raise _SkipField
@property @property
def _model_fields_set(self): def _model_fields_set(self):
""" """
@ -432,8 +518,6 @@ class ModelResource(FormResource):
return model_fields - set(as_tuple(self.exclude)) return model_fields - set(as_tuple(self.exclude))
@property
def _property_fields_set(self): def _property_fields_set(self):
""" """
Returns a set containing the names of validated properties on the model. Returns a set containing the names of validated properties on the model.
@ -447,15 +531,11 @@ class ModelResource(FormResource):
return property_fields.union(set(self.include)) - set(self.exclude) return property_fields.union(set(self.include)) - set(self.exclude)
class ModelMixin(object):
def get_model(self): def get_model(self):
""" """
Return the model class for this view. Return the model class for this view.
""" """
return getattr(self, 'model', self.resource.model) return getattr(self, 'model', getattr(self.view, 'model', None))
def get_queryset(self): def get_queryset(self):
""" """
@ -463,7 +543,7 @@ class ModelMixin(object):
instances. instances.
""" """
return getattr(self, 'queryset', return getattr(self, 'queryset',
getattr(self.resource, 'queryset', getattr(self.view, 'queryset',
self.get_model().objects.all())) self.get_model().objects.all()))
def get_ordering(self): def get_ordering(self):
@ -471,123 +551,14 @@ class ModelMixin(object):
Return the ordering that should be used when listing instances. Return the ordering that should be used when listing instances.
""" """
return getattr(self, 'ordering', return getattr(self, 'ordering',
getattr(self.resource, 'ordering', getattr(self.view, 'ordering',
None)) None))
# Underlying instance API... def _clean_url_kwargs(self, kwargs):
# TODO: probably this functionality shouldn't be there
def get_instance(self, *args, **kwargs): from djangorestframework.renderers import BaseRenderer
"""
Return a model instance or None.
"""
model = self.get_model()
queryset = self.get_queryset()
try:
return queryset.get(**kwargs)
except model.DoesNotExist:
return None
def create_instance(self, *args, **kwargs):
model = self.get_model()
m2m_data = {}
for field in model._meta.many_to_many:
if field.name in kwargs:
m2m_data[field.name] = (
field.m2m_reverse_field_name(), kwargs[field.name]
)
del kwargs[field.name]
instance = model(**kwargs)
instance.save()
for fieldname in m2m_data:
manager = getattr(instance, fieldname)
if hasattr(manager, 'add'):
manager.add(*m2m_data[fieldname][1])
else:
data = {}
data[manager.source_field_name] = instance
for related_item in m2m_data[fieldname][1]:
data[m2m_data[fieldname][0]] = related_item
manager.through(**data).save()
return instance
def update_instance(self, instance, *args, **kwargs):
for (key, val) in kwargs.items():
setattr(instance, key, val)
instance.save()
return instance
def delete_instance(self, instance, *args, **kwargs):
instance.delete()
return instance
def list_instances(self, *args, **kwargs):
queryset = self.get_queryset()
ordering = self.get_ordering()
if ordering:
queryset = queryset.order_by(ordering)
return queryset.filter(**kwargs)
# Request/Response layer...
def _get_url_kwargs(self, kwargs):
format_arg = BaseRenderer._FORMAT_QUERY_PARAM format_arg = BaseRenderer._FORMAT_QUERY_PARAM
if format_arg in kwargs: if format_arg in kwargs:
kwargs = kwargs.copy() kwargs = kwargs.copy()
del kwargs[format_arg] del kwargs[format_arg]
return kwargs return kwargs
def _get_content_kwargs(self, kwargs):
return dict(self._get_url_kwargs(kwargs).items() +
self.CONTENT.items())
def read(self, request, *args, **kwargs):
kwargs = self._get_url_kwargs(kwargs)
instance = self.get_instance(**kwargs)
if instance is None:
raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {})
return instance
def update(self, request, *args, **kwargs):
kwargs = self._get_url_kwargs(kwargs)
instance = self.get_instance(**kwargs)
kwargs = self._get_content_kwargs(kwargs)
if instance:
instance = self.update_instance(instance, **kwargs)
else:
instance = self.create_instance(**kwargs)
return instance
def create(self, request, *args, **kwargs):
kwargs = self._get_content_kwargs(kwargs)
instance = self.create_instance(**kwargs)
headers = {}
try:
headers['Location'] = self.resource(self).url(instance)
except: # TODO: _SkipField should not really happen.
pass
return Response(status.HTTP_201_CREATED, instance, headers)
def destroy(self, request, *args, **kwargs):
kwargs = self._get_url_kwargs(kwargs)
instance = self.delete_instance(**kwargs)
if not instance:
raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {})
return instance
def list(self, request, *args, **kwargs):
return self.list_instances(**kwargs)

View File

@ -4,120 +4,12 @@ from django.utils import simplejson as json
from djangorestframework import status from djangorestframework import status
from djangorestframework.compat import RequestFactory from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from djangorestframework.mixins import PaginatorMixin, ModelMixin from djangorestframework.mixins import PaginatorMixin
from djangorestframework.resources import ModelResource
from djangorestframework.response import Response from djangorestframework.response import Response
from djangorestframework.tests.models import CustomUser
from djangorestframework.tests.testcases import TestModelsTestCase from djangorestframework.tests.testcases import TestModelsTestCase
from djangorestframework.views import View from djangorestframework.views import View
class TestModelCreation(TestModelsTestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
super(TestModelsTestCase, self).setUp()
self.req = RequestFactory()
def test_creation(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
form_data = {'name': 'foo'}
request = self.req.post('/groups', data=form_data)
mixin = ModelMixin()
mixin.resource = GroupResource
mixin.CONTENT = form_data
response = mixin.create(request)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', response.cleaned_content.name)
def test_creation_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
form_data = {
'username': 'bar',
'password': 'baz',
'groups': [group.id]
}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(1, User.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
def test_creation_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
def url(self, instance):
return "/customusers/%i" % instance.id
form_data = {'username': 'bar0', 'groups': []}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = []
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(1, CustomUser.objects.count())
self.assertEquals(0, response.cleaned_content.groups.count())
group = Group(name='foo1')
group.save()
form_data = {'username': 'bar1', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(2, CustomUser.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
group2 = Group(name='foo2')
group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group, group2]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(3, CustomUser.objects.count())
self.assertEquals(2, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
class MockPaginatorView(PaginatorMixin, View): class MockPaginatorView(PaginatorMixin, View):
total = 60 total = 60

View File

@ -23,12 +23,12 @@ class CustomUserResource(ModelResource):
model = CustomUser model = CustomUser
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'), url(r'^users/$', ListOrCreateModelView.as_view(resource_class=UserResource), name='users'),
url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)), url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource_class=UserResource)),
url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'), url(r'^customusers/$', ListOrCreateModelView.as_view(resource_class=CustomUserResource), name='customusers'),
url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)), url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource_class=CustomUserResource)),
url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'), url(r'^groups/$', ListOrCreateModelView.as_view(resource_class=GroupResource), name='groups'),
url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)), url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource_class=GroupResource)),
) )

View File

@ -0,0 +1,132 @@
from djangorestframework.views import View
from djangorestframework.resources import ModelResource
from djangorestframework.tests.testcases import TestModelsTestCase
from djangorestframework.compat import RequestFactory
from djangorestframework.tests.models import CustomUser
from django.contrib.auth.models import Group, User
class MockView(View):
"""This is a basic mock view"""
pass
class TestModelCreation(TestModelsTestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
super(TestModelsTestCase, self).setUp()
self.req = RequestFactory()
def test_create(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
request = self.req.post('/groups', data={})
args = []
kwargs = {'name': 'foo'}
resource = GroupResource(view=MockView.as_view())
resource.create(request, *args, **kwargs)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', resource.instance.name)
def test_update(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
group = Group(name='foo')
group.save()
request = self.req.post('/groups', data={})
args = []
kwargs = {}
data = {'name': 'bla'}
resource = GroupResource(instance=group, view=MockView.as_view())
resource.update(data, request, *args, **kwargs)
self.assertEquals('bla', resource.instance.name)
def test_update_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
user = User(username='bar')
user.save()
form_data = {
'username': 'bar',
'password': 'baz',
'groups': [group.id]
}
request = self.req.post('/groups', data=form_data)
args = []
kwargs = {}
cleaned_data = dict(form_data, groups=[group])
resource = UserResource(instance=user, view=MockView.as_view())
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(1, resource.instance.groups.count())
self.assertEquals('foo', resource.instance.groups.all()[0].name)
def test_update_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
def url(self, instance):
return "/customusers/%i" % instance.id
user = User(username='bar')
user.save()
form_data = {'groups': []}
request = self.req.post('/groups', data=form_data)
args = []
kwargs = {}
cleaned_data = dict(form_data, groups=[])
resource = UserResource(instance=user, view=MockView.as_view())
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(0, resource.instance.groups.count())
group = Group(name='foo1')
group.save()
form_data = {'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data, groups=[group])
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(1, resource.instance.groups.count())
self.assertEquals('foo1', resource.instance.groups.all()[0].name)
group2 = Group(name='foo2')
group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data, groups=[group, group2])
resource.update(cleaned_data, request, *args, **kwargs)
self.assertEquals(2, resource.instance.groups.count())
self.assertEquals('foo1', resource.instance.groups.all()[0].name)
self.assertEquals('foo2', resource.instance.groups.all()[1].name)

View File

@ -23,7 +23,7 @@ class MockView_PerViewThrottling(MockView):
class MockView_PerResourceThrottling(MockView): class MockView_PerResourceThrottling(MockView):
permissions = ( PerResourceThrottling, ) permissions = ( PerResourceThrottling, )
resource = FormResource resource_class = FormResource
class MockView_MinuteThrottling(MockView): class MockView_MinuteThrottling(MockView):
throttle = '3/min' throttle = '3/min'

View File

@ -83,7 +83,7 @@ class TestNonFieldErrors(TestCase):
view = MockView() view = MockView()
content = {'field1': 'example1', 'field2': 'example2'} content = {'field1': 'example1', 'field2': 'example2'}
try: try:
MockResource(view).validate_request(content, None) MockResource(view=view).validate_request(content, None)
except ErrorResponse, exc: except ErrorResponse, exc:
self.assertEqual(exc.response.raw_content, {'errors': [MockForm.ERROR_TEXT]}) self.assertEqual(exc.response.raw_content, {'errors': [MockForm.ERROR_TEXT]})
else: else:
@ -187,77 +187,77 @@ class TestFormValidation(TestCase):
# Tests on FormResource # Tests on FormResource
def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self): def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
def test_form_validation_failure_raises_response_exception(self): def test_form_validation_failure_raises_response_exception(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failure_raises_response_exception(validator) self.validation_failure_raises_response_exception(validator)
def test_validation_does_not_allow_extra_fields_by_default(self): def test_validation_does_not_allow_extra_fields_by_default(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_does_not_allow_extra_fields_by_default(validator) self.validation_does_not_allow_extra_fields_by_default(validator)
def test_validation_allows_extra_fields_if_explicitly_set(self): def test_validation_allows_extra_fields_if_explicitly_set(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_allows_extra_fields_if_explicitly_set(validator) self.validation_allows_extra_fields_if_explicitly_set(validator)
def test_validation_does_not_require_extra_fields_if_explicitly_set(self): def test_validation_does_not_require_extra_fields_if_explicitly_set(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_does_not_require_extra_fields_if_explicitly_set(validator) self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
def test_validation_failed_due_to_no_content_returns_appropriate_message(self): def test_validation_failed_due_to_no_content_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_no_content_returns_appropriate_message(validator) self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
def test_validation_failed_due_to_field_error_returns_appropriate_message(self): def test_validation_failed_due_to_field_error_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_field_error_returns_appropriate_message(validator) self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self): def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
validator = self.MockFormResource(self.MockFormView()) validator = self.MockFormResource(view=self.MockFormView())
self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
# Same tests on ModelResource # Same tests on ModelResource
def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self): def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
def test_modelform_validation_failure_raises_response_exception(self): def test_modelform_validation_failure_raises_response_exception(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failure_raises_response_exception(validator) self.validation_failure_raises_response_exception(validator)
def test_modelform_validation_does_not_allow_extra_fields_by_default(self): def test_modelform_validation_does_not_allow_extra_fields_by_default(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_does_not_allow_extra_fields_by_default(validator) self.validation_does_not_allow_extra_fields_by_default(validator)
def test_modelform_validation_allows_extra_fields_if_explicitly_set(self): def test_modelform_validation_allows_extra_fields_if_explicitly_set(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_allows_extra_fields_if_explicitly_set(validator) self.validation_allows_extra_fields_if_explicitly_set(validator)
def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self): def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_does_not_require_extra_fields_if_explicitly_set(validator) self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_no_content_returns_appropriate_message(validator) self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_field_error_returns_appropriate_message(validator) self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
validator = self.MockModelResource(self.MockModelFormView()) validator = self.MockModelResource(view=self.MockModelFormView())
self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
@ -280,7 +280,7 @@ class TestModelFormValidator(TestCase):
class MockView(View): class MockView(View):
resource = MockResource resource = MockResource
self.validator = MockResource(MockView) self.validator = MockResource(view=MockView)
def test_property_fields_are_allowed_on_model_forms(self): def test_property_fields_are_allowed_on_model_forms(self):

View File

@ -44,8 +44,8 @@ urlpatterns = patterns('djangorestframework.utils.staticviews',
url(r'^accounts/logout$', 'api_logout'), url(r'^accounts/logout$', 'api_logout'),
url(r'^mock/$', MockView.as_view()), url(r'^mock/$', MockView.as_view()),
url(r'^resourcemock/$', ResourceMockView.as_view()), url(r'^resourcemock/$', ResourceMockView.as_view()),
url(r'^model/$', ListOrCreateModelView.as_view(resource=MockResource)), url(r'^model/$', ListOrCreateModelView.as_view(resource_class=MockResource)),
url(r'^model/(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource=MockResource)), url(r'^model/(?P<pk>[^/]+)/$', InstanceModelView.as_view(resource_class=MockResource)),
) )
class BaseViewTests(TestCase): class BaseViewTests(TestCase):

View File

@ -19,9 +19,14 @@ def get_name(view):
if getattr(view, 'cls_instance', None): if getattr(view, 'cls_instance', None):
view = view.cls_instance view = view.cls_instance
# If the view seems to have a resource class, we get it
resource_class = None
if hasattr(view, 'get_resource_class'):
resource_class = view.get_resource_class()
# If this view has a resource that's been overridden, then use that resource for the name # If this view has a resource that's been overridden, then use that resource for the name
if getattr(view, 'resource', None) not in (None, Resource, FormResource, ModelResource): if resource_class not in (None, Resource, FormResource, ModelResource):
name = view.resource.__name__ name = resource_class.__name__
# Chomp of any non-descriptive trailing part of the resource class name # Chomp of any non-descriptive trailing part of the resource class name
if name.endswith('Resource') and name != 'Resource': if name.endswith('Resource') and name != 'Resource':
@ -63,10 +68,14 @@ def get_description(view):
if getattr(view, 'cls_instance', None): if getattr(view, 'cls_instance', None):
view = view.cls_instance view = view.cls_instance
# If the view seems to have a resource class, we get it
resource_class = None
if hasattr(view, 'get_resource_class'):
resource_class = view.get_resource_class()
# If this view has a resource that's been overridden, then use the resource's doctring # If this view has a resource that's been overridden, then use the resource's doctring
if getattr(view, 'resource', None) not in (None, Resource, FormResource, ModelResource): if resource_class not in (None, Resource, FormResource, ModelResource):
doc = view.resource.__doc__ doc = resource_class.__doc__
# Otherwise use the view doctring # Otherwise use the view doctring
elif getattr(view, '__doc__', None): elif getattr(view, '__doc__', None):

View File

@ -178,14 +178,16 @@ class View(ResourceMixin, RequestMixin, ResponseMixin, AuthMixin, DjangoView):
return ret return ret
class ModelView(ModelMixin, View): class ModelView(View):
""" """
A RESTful view that maps to a model in the database. A RESTful view that maps to a model in the database.
""" """
resource_class = resources.ModelResource resource_class = resources.ModelResource
model = None
class InstanceModelView(InstanceMixin, ModelView): class InstanceModelView(GetResourceMixin, PutResourceMixin, DeleteResourceMixin,
InstanceMixin, ModelView):
""" """
A view which provides default operations for read/update/delete against a A view which provides default operations for read/update/delete against a
model instance. This view is also treated as the Canonical identifier model instance. This view is also treated as the Canonical identifier
@ -193,27 +195,18 @@ class InstanceModelView(InstanceMixin, ModelView):
""" """
_suffix = 'Instance' _suffix = 'Instance'
get = ModelMixin.read
put = ModelMixin.update
delete = ModelMixin.destroy
class ListModelView(ListResourceMixin, ModelView):
class ListModelView(ModelView):
""" """
A view which provides default operations for list, against a model in the A view which provides default operations for list, against a model in the
database. database.
""" """
_suffix = 'List' _suffix = 'List'
get = ModelMixin.list
class ListOrCreateModelView(PostResourceMixin, ListResourceMixin, ModelView):
class ListOrCreateModelView(ModelView):
""" """
A view which provides default operations for list and create, against a A view which provides default operations for list and create, against a
model in the database. model in the database.
""" """
_suffix = 'List' _suffix = 'List'
get = ModelMixin.list
post = ModelMixin.create