diff --git a/djangorestframework/mixins.py b/djangorestframework/mixins.py index 4a4539574..90c883072 100644 --- a/djangorestframework/mixins.py +++ b/djangorestframework/mixins.py @@ -181,7 +181,7 @@ class RequestMixin(object): return parser.parse(stream) raise ErrorResponse(status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, - {'detail': 'Unsupported media type in request \'%s\'.' % + {'detail': 'Unsupported media type in request \'%s\'.' % content_type}) @property @@ -454,9 +454,6 @@ class ModelMixin(object): If a *ModelMixin is going to retrive an instance (or queryset) using args and kwargs passed by as URL arguments, it should provied arguments to objects.get and objects.filter methods wrapped in by `build_query` - - If a *ModelMixin is going to create/update an instance get_instance_data - handles the instance data creation/preaparation. """ queryset = None @@ -477,7 +474,104 @@ class ModelMixin(object): return kwargs - def get_instance_data(self, model, content, **kwargs): + def get_queryset(self): + """ + Return the queryset for this view. + """ + return getattr(self.resource, 'queryset', + self.resource.model.objects.all()) + + def get_ordering(self): + """ + Return the ordering for this view. + """ + return getattr(self.resource, 'ordering', None) + + +class InstanceReaderMixin (object): + """ + Assume a single instance for the view. Caches the instance object on self. + """ + + def get_instance(self): + """ + Return the instance for this view. Raises a DoesNotExist error if + instance does not exist. + """ + if not hasattr(self, 'model_instance'): + query_kwargs = self.get_query_kwargs( + self.request, *self.args, **self.kwargs) + self.model_instance = self.get_queryset().get(**query_kwargs) + return self.model_instance + + def get_instance_or_404(self): + """ + Return the instance for this view, or raise a 404 response if the + instance is not found. + """ + model = self.resource.model + + try: + return self.get_instance() + except model.DoesNotExist: + raise ErrorResponse(status.HTTP_404_NOT_FOUND) + +class InstanceWriterMixin (object): + """ + Abstracts out the logic for applying many-to-many data to a single instance. + """ + + def _separate_m2m_data_from_content(self, model, content): + """ + Split the view's CONTENT into atomic data and many-to-many data. + Retrns a pair of (non_m2m_data, m2m_data) + + Arguments: + - model: model class (django.db.models.Model subclass) to work with + - content: a dictionary with instance data + """ + # Copy the dict to keep it intact + content = dict(content) + m2m_data = {} + + for field in model._meta.many_to_many: + if field.name in content: + m2m_data[field.name] = ( + field.m2m_reverse_field_name(), content[field.name] + ) + del content[field.name] + + non_m2m_data = content + return non_m2m_data, m2m_data + + def _set_m2m_data(self, instance, m2m_data): + """ + Apply the many-to-many data to the given instance. + + Arguments: + - instance: model instance to work with + - m2m_data: a mapping from fieldname to list of identifiers of related + objects + """ + for fieldname in m2m_data: + manager = getattr(instance, fieldname) + + # If we are updating an existing model, we want to clear out + # existing relationships. + if hasattr(manager, 'clear'): + manager.clear() + + if hasattr(manager, 'add'): + manager.add(*m2m_data[fieldname][1]) + else: + data = {} + data[manager.source_field_name] = instance + + for related_item in m2m_data[fieldname][1]: + data[m2m_data[fieldname][0]] = related_item + manager.through(**data).save() + + def _get_instance_data(self, model, content, **kwargs): """ Returns the dict with the data for model instance creation/update. @@ -504,117 +598,104 @@ class ModelMixin(object): return all_kw_args - def get_instance(self, **kwargs): + def create_instance(self): """ - Get a model instance for read/update/delete requests. + Create the instance for this view. """ - return self.get_queryset().get(**kwargs) - - def get_queryset(self): - """ - Return the queryset for this view. - """ - return getattr(self.resource, 'queryset', - self.resource.model.objects.all()) - - def get_ordering(self): - """ - Return the ordering for this view. - """ - return getattr(self.resource, 'ordering', None) - - -class ReadModelMixin(ModelMixin): - """ - Behavior to read a `model` instance on GET requests - """ - def get(self, request, *args, **kwargs): - model = self.resource.model - query_kwargs = self.get_query_kwargs(request, *args, **kwargs) - - try: - self.model_instance = self.get_instance(**query_kwargs) - except model.DoesNotExist: - raise ErrorResponse(status.HTTP_404_NOT_FOUND) - - return self.model_instance - - -class CreateModelMixin(ModelMixin): - """ - Behavior to create a `model` instance on POST requests - """ - def post(self, request, *args, **kwargs): model = self.resource.model - # Copy the dict to keep self.CONTENT intact - content = dict(self.CONTENT) - m2m_data = {} + content, m2m_data = self._separate_m2m_data_from_content(model, + self.CONTENT) + instance_data = self._get_instance_data(model, content, + *self.args, **self.kwargs) - for field in model._meta.many_to_many: - if field.name in content: - m2m_data[field.name] = ( - field.m2m_reverse_field_name(), content[field.name] - ) - del content[field.name] - - instance = model(**self.get_instance_data(model, content, *args, **kwargs)) + instance = model(**instance_data) instance.save() - for fieldname in m2m_data: - manager = getattr(instance, fieldname) + self._set_m2m_data(instance, m2m_data) + return instance - if hasattr(manager, 'add'): - manager.add(*m2m_data[fieldname][1]) - else: - data = {} - data[manager.source_field_name] = instance - - for related_item in m2m_data[fieldname][1]: - data[m2m_data[fieldname][0]] = related_item - manager.through(**data).save() - - headers = {} - if hasattr(self.resource, 'url'): - headers['Location'] = self.resource(self).url(instance) - return Response(status.HTTP_201_CREATED, instance, headers) - - -class UpdateModelMixin(ModelMixin): - """ - Behavior to update a `model` instance on PUT requests - """ - def put(self, request, *args, **kwargs): + def create_or_update_instance(self): + """ + Update the instance for this view, or create it if it does not yet + exist. Assumes the view is also an InstanceReaderMixin + """ model = self.resource.model - query_kwargs = self.get_query_kwargs(request, *args, **kwargs) + + instance_is_new = False + content, m2m_data = self._separate_m2m_data_from_content(model, + self.CONTENT) + instance_data = self._get_instance_data(model, content, + *self.args, **self.kwargs) # TODO: update on the url of a non-existing resource url doesn't work # correctly at the moment - will end up with a new url try: - self.model_instance = self.get_instance(**query_kwargs) + instance = self.get_instance() + + for (key, val) in instance_data.items(): + setattr(instance, key, val) - for (key, val) in self.CONTENT.items(): - setattr(self.model_instance, key, val) except model.DoesNotExist: - self.model_instance = model(**self.get_instance_data(model, self.CONTENT, *args, **kwargs)) - self.model_instance.save() - return self.model_instance + instance_is_new = True + instance = model(**instance_data) + + instance.save() + self._set_m2m_data(instance, m2m_data) + + return instance, instance_is_new + + def delete_instance(self): + """ + Delete the instance for this view. Assumes the view is also an + InstanceReaderMixin. + """ + instance = self.get_instance_or_404() + instance.delete() -class DeleteModelMixin(ModelMixin): +class ReadModelMixin(ModelMixin, InstanceReaderMixin): + """ + Behavior to read a `model` instance on GET requests + """ + def get(self, request, *args, **kwargs): + instance = self.get_instance_or_404() + return instance + + +class CreateModelMixin(ModelMixin, InstanceWriterMixin): + """ + Behavior to create a `model` instance on POST requests + """ + def post(self, request, *args, **kwargs): + instance = self.create_instance() + + headers = {} + if hasattr(self.resource, 'url'): + headers['Location'] = self.resource(self).url(instance) + + return Response(status.HTTP_201_CREATED, instance, headers) + + +class UpdateModelMixin(ModelMixin, InstanceReaderMixin, InstanceWriterMixin): + """ + Behavior to update a `model` instance on PUT requests + """ + def put(self, request, *args, **kwargs): + instance, instance_is_new = self.create_or_update_instance() + + if instance_is_new: + return Response(status.HTTP_201_CREATED, instance) + else: + return instance + + +class DeleteModelMixin(ModelMixin, InstanceReaderMixin, InstanceWriterMixin): """ Behavior to delete a `model` instance on DELETE requests """ def delete(self, request, *args, **kwargs): - model = self.resource.model - query_kwargs = self.get_query_kwargs(request, *args, **kwargs) - - try: - instance = self.get_instance(**query_kwargs) - except model.DoesNotExist: - raise ErrorResponse(status.HTTP_404_NOT_FOUND, None, {}) - - instance.delete() + self.delete_instance() return diff --git a/djangorestframework/tests/mixins.py b/djangorestframework/tests/mixins.py index 8268fdca7..2727b2c13 100644 --- a/djangorestframework/tests/mixins.py +++ b/djangorestframework/tests/mixins.py @@ -4,7 +4,9 @@ from django.utils import simplejson as json from djangorestframework import status from djangorestframework.compat import RequestFactory from django.contrib.auth.models import Group, User -from djangorestframework.mixins import CreateModelMixin, PaginatorMixin, ReadModelMixin +from djangorestframework.mixins import (CreateModelMixin, DeleteModelMixin, + PaginatorMixin, ReadModelMixin, + UpdateModelMixin) from djangorestframework.resources import ModelResource from djangorestframework.response import Response, ErrorResponse from djangorestframework.tests.models import CustomUser @@ -30,6 +32,10 @@ class TestModelRead(TestModelsTestCase): mixin = ReadModelMixin() mixin.resource = GroupResource + mixin.request = request + mixin.args = () + mixin.kwargs = {'id': group.id} + response = mixin.get(request, id=group.id) self.assertEquals(group.name, response.name) @@ -41,6 +47,10 @@ class TestModelRead(TestModelsTestCase): mixin = ReadModelMixin() mixin.resource = GroupResource + mixin.request = request + mixin.args = () + mixin.kwargs = {'id': 12345} + self.assertRaises(ErrorResponse, mixin.get, request, id=12345) @@ -63,6 +73,10 @@ class TestModelCreation(TestModelsTestCase): mixin.resource = GroupResource mixin.CONTENT = form_data + mixin.request = request + mixin.args = () + mixin.kwargs = {} + response = mixin.post(request) self.assertEquals(1, Group.objects.count()) self.assertEquals('foo', response.cleaned_content.name) @@ -89,6 +103,10 @@ class TestModelCreation(TestModelsTestCase): mixin.resource = UserResource mixin.CONTENT = cleaned_data + mixin.request = request + mixin.args = () + mixin.kwargs = {} + response = mixin.post(request) self.assertEquals(1, User.objects.count()) self.assertEquals(1, response.cleaned_content.groups.count()) @@ -112,6 +130,10 @@ class TestModelCreation(TestModelsTestCase): mixin.resource = UserResource mixin.CONTENT = cleaned_data + mixin.request = request + mixin.args = () + mixin.kwargs = {} + response = mixin.post(request) self.assertEquals(1, CustomUser.objects.count()) self.assertEquals(0, response.cleaned_content.groups.count()) @@ -127,6 +149,10 @@ class TestModelCreation(TestModelsTestCase): mixin.resource = UserResource mixin.CONTENT = cleaned_data + mixin.request = request + mixin.args = () + mixin.kwargs = {} + response = mixin.post(request) self.assertEquals(2, CustomUser.objects.count()) self.assertEquals(1, response.cleaned_content.groups.count()) @@ -143,6 +169,10 @@ class TestModelCreation(TestModelsTestCase): mixin.resource = UserResource mixin.CONTENT = cleaned_data + mixin.request = request + mixin.args = () + mixin.kwargs = {} + response = mixin.post(request) self.assertEquals(3, CustomUser.objects.count()) self.assertEquals(2, response.cleaned_content.groups.count()) @@ -150,6 +180,199 @@ class TestModelCreation(TestModelsTestCase): self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name) +class TestModelUpdate(TestModelsTestCase): + """Tests on UpdateModelMixin""" + + def setUp(self): + super(TestModelsTestCase, self).setUp() + self.req = RequestFactory() + + def test_update(self): + group = Group.objects.create(name='my group') + + self.assertEquals(1, Group.objects.count()) + + class GroupResource(ModelResource): + model = Group + + # Update existing + form_data = {'name': 'my renamed group'} + request = self.req.put('/groups/' + str(group.pk), data=form_data) + mixin = UpdateModelMixin() + mixin.resource = GroupResource + mixin.CONTENT = form_data + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': group.pk} + + response = mixin.put(request, **mixin.kwargs) + self.assertEquals(1, Group.objects.count()) + self.assertEquals('my renamed group', response.name) + + # Create new + form_data = {'name': 'other group'} + request = self.req.put('/groups/' + str(group.pk + 1), data=form_data) + mixin = UpdateModelMixin() + mixin.resource = GroupResource + mixin.CONTENT = form_data + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': group.pk + 1} + + response = mixin.put(request, **mixin.kwargs) + self.assertEquals(2, Group.objects.count()) + self.assertEquals('other group', response.cleaned_content.name) + self.assertEquals(201, response.status) + + def test_update_with_m2m_relation(self): + class UserResource(ModelResource): + model = User + + def url(self, instance): + return "/users/%i" % instance.id + + group = Group(name='foo') + group.save() + + user = User.objects.create_user(username='bar', password='blah', email="bar@example.com") + self.assertEquals(1, User.objects.count()) + self.assertEquals(0, user.groups.count()) + + form_data = { + 'password': 'baz', + 'groups': [group.id] + } + request = self.req.post('/users/' + str(user.pk), data=form_data) + cleaned_data = dict(form_data) + cleaned_data['groups'] = [group] + mixin = UpdateModelMixin() + mixin.resource = UserResource + mixin.CONTENT = cleaned_data + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': user.pk} + + response = mixin.put(request, **mixin.kwargs) + self.assertEquals(1, User.objects.count()) + self.assertEquals(1, response.groups.count()) + self.assertEquals('foo', response.groups.all()[0].name) + self.assertEquals('bar', response.username) + + def test_update_with_m2m_relation_through(self): + """ + Tests updating where the m2m relation uses a through table + """ + class UserResource(ModelResource): + model = CustomUser + + def url(self, instance): + return "/customusers/%i" % instance.id + + user = CustomUser.objects.create(username='bar') + + # Update existing resource with empty relation + form_data = {'username': 'bar0', 'groups': []} + request = self.req.put('/users/' + str(user.pk), data=form_data) + cleaned_data = dict(form_data) + cleaned_data['groups'] = [] + mixin = UpdateModelMixin() + mixin.resource = UserResource + mixin.CONTENT = cleaned_data + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': user.pk} + + response = mixin.put(request, **mixin.kwargs) + self.assertEquals(1, CustomUser.objects.count()) + self.assertEquals(0, response.groups.count()) + + # Update existing resource with one relation + group = Group(name='foo1') + group.save() + + form_data = {'username': 'bar1', 'groups': [group.id]} + request = self.req.put('/users/' + str(user.pk), data=form_data) + cleaned_data = dict(form_data) + cleaned_data['groups'] = [group] + mixin = UpdateModelMixin() + mixin.resource = UserResource + mixin.CONTENT = cleaned_data + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': user.pk} + + response = mixin.put(request, **mixin.kwargs) + self.assertEquals(1, CustomUser.objects.count()) + self.assertEquals(1, response.groups.count()) + self.assertEquals('foo1', response.groups.all()[0].name) + + # Update existing resource with more than one relation + group2 = Group(name='foo2') + group2.save() + + form_data = {'username': 'bar2', 'groups': [group.id, group2.id]} + request = self.req.put('/users/' + str(user.pk), data=form_data) + cleaned_data = dict(form_data) + cleaned_data['groups'] = [group, group2] + mixin = UpdateModelMixin() + mixin.resource = UserResource + mixin.CONTENT = cleaned_data + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': user.pk} + + response = mixin.put(request, **mixin.kwargs) + self.assertEquals(1, CustomUser.objects.count()) + self.assertEquals(2, response.groups.count()) + self.assertEquals('foo1', response.groups.all()[0].name) + self.assertEquals('foo2', response.groups.all()[1].name) + + +class TestModelDelete(TestModelsTestCase): + """Tests on DeleteModelMixin""" + + def setUp(self): + super(TestModelsTestCase, self).setUp() + self.req = RequestFactory() + + def test_delete(self): + group = Group.objects.create(name='my group') + + self.assertEquals(1, Group.objects.count()) + + class GroupResource(ModelResource): + model = Group + + # Delete existing + request = self.req.delete('/groups/' + str(group.pk)) + mixin = DeleteModelMixin() + mixin.resource = GroupResource + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': group.pk} + + response = mixin.delete(request, **mixin.kwargs) + self.assertEquals(0, Group.objects.count()) + + # Delete at non-existing + request = self.req.delete('/groups/' + str(group.pk)) + mixin = DeleteModelMixin() + mixin.resource = GroupResource + + mixin.request = request + mixin.args = () + mixin.kwargs = {'pk': group.pk} + + self.assertRaises(ErrorResponse, mixin.delete, request, **mixin.kwargs) + + class MockPaginatorView(PaginatorMixin, View): total = 60