This commit is contained in:
GitHub Merge Button 2011-07-11 12:33:49 -07:00
commit 2b90906293
5 changed files with 127 additions and 2 deletions

View File

@ -524,7 +524,9 @@ class CreateModelMixin(object):
for field in model._meta.many_to_many: for field in model._meta.many_to_many:
if content.has_key(field.name): if content.has_key(field.name):
m2m_data[field.name] = content[field.name] m2m_data[field.name] = (
field.m2m_reverse_field_name(), content[field.name]
)
del content[field.name] del content[field.name]
all_kw_args = dict(content.items() + kwargs.items()) all_kw_args = dict(content.items() + kwargs.items())
@ -536,7 +538,17 @@ class CreateModelMixin(object):
instance.save() instance.save()
for fieldname in m2m_data: for fieldname in m2m_data:
getattr(instance, fieldname).add(*m2m_data[fieldname]) manager = getattr(instance, fieldname)
if hasattr(manager, 'add'):
manager.add(*m2m_data[fieldname][1])
else:
data = {}
data[manager.source_field_name] = instance
for related_item in m2m_data[fieldname][1]:
data[m2m_data[fieldname][0]] = related_item
manager.through(**data).save()
headers = {} headers = {}
if hasattr(instance, 'get_absolute_url'): if hasattr(instance, 'get_absolute_url'):

View File

@ -95,6 +95,7 @@ INSTALLED_APPS = (
# Uncomment the next line to enable admin documentation: # Uncomment the next line to enable admin documentation:
# 'django.contrib.admindocs', # 'django.contrib.admindocs',
'djangorestframework', 'djangorestframework',
'djangorestframework.tests',
) )
# OAuth support is optional, so we only test oauth if it's installed. # OAuth support is optional, so we only test oauth if it's installed.

View File

@ -5,6 +5,7 @@ from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from djangorestframework.mixins import CreateModelMixin from djangorestframework.mixins import CreateModelMixin
from djangorestframework.resources import ModelResource from djangorestframework.resources import ModelResource
from djangorestframework.tests.models import CustomUser
class TestModelCreation(TestCase): class TestModelCreation(TestCase):
@ -53,4 +54,60 @@ class TestModelCreation(TestCase):
self.assertEquals(1, response.cleaned_content.groups.count()) self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name) self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
def test_creation_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
def url(self, instance):
return "/customusers/%i" % instance.id
form_data = {'username': 'bar0', 'groups': []}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = []
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(1, CustomUser.objects.count())
self.assertEquals(0, response.cleaned_content.groups.count())
group = Group(name='foo1')
group.save()
form_data = {'username': 'bar1', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(2, CustomUser.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
group2 = Group(name='foo2')
group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group, group2]
mixin = CreateModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.post(request)
self.assertEquals(3, CustomUser.objects.count())
self.assertEquals(2, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)

View File

@ -0,0 +1,28 @@
from django.db import models
from django.contrib.auth.models import Group
class CustomUser(models.Model):
"""
A custom user model, which uses a 'through' table for the foreign key
"""
username = models.CharField(max_length=255, unique=True)
groups = models.ManyToManyField(
to=Group, blank=True, null=True, through='UserGroupMap'
)
@models.permalink
def get_absolute_url(self):
return ('custom_user', (), {
'pk': self.id
})
class UserGroupMap(models.Model):
user = models.ForeignKey(to=CustomUser)
group = models.ForeignKey(to=Group)
@models.permalink
def get_absolute_url(self):
return ('user_group_map', (), {
'pk': self.id
})

View File

@ -4,6 +4,7 @@ from django.forms import ModelForm
from django.contrib.auth.models import Group, User from django.contrib.auth.models import Group, User
from djangorestframework.resources import ModelResource from djangorestframework.resources import ModelResource
from djangorestframework.views import ListOrCreateModelView, InstanceModelView from djangorestframework.views import ListOrCreateModelView, InstanceModelView
from djangorestframework.tests.models import CustomUser
class GroupResource(ModelResource): class GroupResource(ModelResource):
model = Group model = Group
@ -17,9 +18,14 @@ class UserResource(ModelResource):
model = User model = User
form = UserForm form = UserForm
class CustomUserResource(ModelResource):
model = CustomUser
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'), url(r'^users/$', ListOrCreateModelView.as_view(resource=UserResource), name='users'),
url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)), url(r'^users/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=UserResource)),
url(r'^customusers/$', ListOrCreateModelView.as_view(resource=CustomUserResource), name='customusers'),
url(r'^customusers/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=CustomUserResource)),
url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'), url(r'^groups/$', ListOrCreateModelView.as_view(resource=GroupResource), name='groups'),
url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)), url(r'^groups/(?P<id>[0-9]+)/$', InstanceModelView.as_view(resource=GroupResource)),
) )
@ -57,3 +63,24 @@ class ModelViewTests(TestCase):
group = user.groups.all()[0] group = user.groups.all()[0]
self.assertEqual('foo', group.name) self.assertEqual('foo', group.name)
def test_creation_with_m2m_relation_through(self):
"""
Ensure that a model object with a m2m relation can be created where that
relation uses a through table
"""
group = Group(name='foo')
group.save()
self.assertEqual(0, User.objects.count())
response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]})
self.assertEqual(response.status_code, 201)
self.assertEqual(1, CustomUser.objects.count())
user = CustomUser.objects.all()[0]
self.assertEqual('bar', user.username)
self.assertEqual(1, user.groups.count())
group = user.groups.all()[0]
self.assertEqual('foo', group.name)