This commit is contained in:
GitHub Merge Button 2011-07-11 12:33:31 -07:00
commit 7f0b291528
3 changed files with 137 additions and 8 deletions

View File

@ -5,7 +5,7 @@ classes that can be added to a `View`.
from django.contrib.auth.models import AnonymousUser
from django.db.models.query import QuerySet
from django.db.models.fields.related import RelatedField
from django.db.models.fields.related import ForeignKey
from django.http import HttpResponse
from django.http.multipartparser import LimitBytes
@ -509,21 +509,35 @@ class CreateModelMixin(object):
"""
Behavior to create a `model` instance on POST requests
"""
def post(self, request, *args, **kwargs):
def post(self, request, *args, **kwargs):
model = self.resource.model
# translated 'related_field' kwargs into 'related_field_id'
for related_name in [field.name for field in model._meta.fields if isinstance(field, RelatedField)]:
if kwargs.has_key(related_name):
kwargs[related_name + '_id'] = kwargs[related_name]
del kwargs[related_name]
# Copy the dict to keep self.CONTENT intact
content = dict(self.CONTENT)
m2m_data = {}
for field in model._meta.fields:
if isinstance(field, ForeignKey) and kwargs.has_key(field.name):
# translate 'related_field' kwargs into 'related_field_id'
kwargs[field.name + '_id'] = kwargs[field.name]
del kwargs[field.name]
for field in model._meta.many_to_many:
if content.has_key(field.name):
m2m_data[field.name] = content[field.name]
del content[field.name]
all_kw_args = dict(content.items() + kwargs.items())
all_kw_args = dict(self.CONTENT.items() + kwargs.items())
if args:
instance = model(pk=args[-1], **all_kw_args)
else:
instance = model(**all_kw_args)
instance.save()
for fieldname in m2m_data:
getattr(instance, fieldname).add(*m2m_data[fieldname])
headers = {}
if hasattr(instance, 'get_absolute_url'):
headers['Location'] = self.resource(self).url(instance)

View File

@ -0,0 +1,56 @@
"""Tests for the status module"""
from django.test import TestCase
from djangorestframework import status
from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User
from djangorestframework.mixins import CreateModelMixin
from djangorestframework.resources import ModelResource
class TestModelCreation(TestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
self.req = RequestFactory()
def test_creation(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
form_data = {'name': 'foo'}
request = self.req.post('/groups', data=form_data)
mixin = CreateModelMixin()
mixin.resource = GroupResource
mixin.CONTENT = form_data
response = mixin.post(request)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', response.cleaned_content.name)
def test_creation_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
form_data = {'username': 'bar', 'password': 'baz', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(1, User.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)

View File

@ -0,0 +1,59 @@
from django.conf.urls.defaults import patterns, url
from django.test import TestCase
from django.forms import ModelForm
from django.contrib.auth.models import Group, User
from djangorestframework.resources import ModelResource
from djangorestframework.views import ListOrCreateModelView, InstanceModelView
class GroupResource(ModelResource):
model = Group
class UserForm(ModelForm):
class Meta:
model = User
exclude = ('last_login', 'date_joined')
class UserResource(ModelResource):
model = User
form = UserForm
urlpatterns = patterns('',
url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
)
class ModelViewTests(TestCase):
"""Test the model views djangorestframework provides"""
urls = 'djangorestframework.tests.modelviews'
def test_creation(self):
"""Ensure that a model object can be created"""
self.assertEqual(0, Group.objects.count())
response = self.client.post('/groups/', {'name': 'foo'})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, Group.objects.count())
self.assertEqual('foo', Group.objects.all()[0].name)
def test_creation_with_m2m_relation(self):
"""Ensure that a model object with a m2m relation can be created"""
group = Group(name='foo')
group.save()
self.assertEqual(0, User.objects.count())
response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, User.objects.count())
user = User.objects.all()[0]
self.assertEqual('bar', user.username)
self.assertEqual('baz', user.password)
self.assertEqual(1, user.groups.count())
group = user.groups.all()[0]
self.assertEqual('foo', group.name)