diff --git a/graphene_django/forms/mutation.py b/graphene_django/forms/mutation.py index 49fabb8..876e76a 100644 --- a/graphene_django/forms/mutation.py +++ b/graphene_django/forms/mutation.py @@ -27,13 +27,13 @@ def fields_for_form(form, only_fields, exclude_fields): return fields -class BaseFormMutation(ClientIDMutation): +class BaseDjangoFormMutation(ClientIDMutation): class Meta: abstract = True @classmethod def mutate_and_get_payload(cls, root, info, **input): - form = cls._meta.form_class(data=input) + form = cls.get_form(root, info, **input) if form.is_valid(): return cls.perform_mutate(form, info) @@ -45,12 +45,28 @@ class BaseFormMutation(ClientIDMutation): return cls(errors=errors) + @classmethod + def get_form(cls, root, info, **input): + form_kwargs = cls.get_form_kwargs(root, info, **input) + return cls._meta.form_class(**form_kwargs) -class FormMutationOptions(MutationOptions): + @classmethod + def get_form_kwargs(cls, root, info, **input): + kwargs = {'data': input} + + pk = input.pop('id', None) + if pk: + instance = cls._meta.model._default_manager.get(pk=pk) + kwargs['instance'] = instance + + return kwargs + + +class DjangoFormMutationOptions(MutationOptions): form_class = None -class FormMutation(BaseFormMutation): +class DjangoFormMutation(BaseDjangoFormMutation): class Meta: abstract = True @@ -67,7 +83,7 @@ class FormMutation(BaseFormMutation): input_fields = fields_for_form(form, only_fields, exclude_fields) output_fields = fields_for_form(form, only_fields, exclude_fields) - _meta = FormMutationOptions(cls) + _meta = DjangoFormMutationOptions(cls) _meta.form_class = form_class _meta.fields = yank_fields_from_attrs( output_fields, @@ -78,7 +94,7 @@ class FormMutation(BaseFormMutation): input_fields, _as=InputField, ) - super(FormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) + super(DjangoFormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) @classmethod def perform_mutate(cls, form, info): @@ -86,12 +102,12 @@ class FormMutation(BaseFormMutation): return cls(errors=[]) -class ModelFormMutationOptions(FormMutationOptions): +class DjangoModelDjangoFormMutationOptions(DjangoFormMutationOptions): model = None return_field_name = None -class ModelFormMutation(BaseFormMutation): +class DjangoModelFormMutation(BaseDjangoFormMutation): class Meta: abstract = True @@ -102,13 +118,13 @@ class ModelFormMutation(BaseFormMutation): only_fields=(), exclude_fields=(), **options): if not form_class: - raise Exception('form_class is required for ModelFormMutation') + raise Exception('form_class is required for DjangoModelFormMutation') if not model: model = form_class._meta.model if not model: - raise Exception('model is required for ModelFormMutation') + raise Exception('model is required for DjangoModelFormMutation') form = form_class() input_fields = fields_for_form(form, only_fields, exclude_fields) @@ -119,7 +135,7 @@ class ModelFormMutation(BaseFormMutation): output_fields = OrderedDict() output_fields[return_field_name] = graphene.Field(model_type) - _meta = ModelFormMutationOptions(cls) + _meta = DjangoModelDjangoFormMutationOptions(cls) _meta.form_class = form_class _meta.model = model _meta.return_field_name = return_field_name @@ -132,7 +148,7 @@ class ModelFormMutation(BaseFormMutation): input_fields, _as=InputField, ) - super(ModelFormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) + super(DjangoModelFormMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) @classmethod def perform_mutate(cls, form, info): diff --git a/graphene_django/forms/tests/test_mutation.py b/graphene_django/forms/tests/test_mutation.py index 084b8b0..5876405 100644 --- a/graphene_django/forms/tests/test_mutation.py +++ b/graphene_django/forms/tests/test_mutation.py @@ -3,7 +3,7 @@ from django.test import TestCase from py.test import raises from graphene_django.tests.models import Pet, Film -from ..mutation import FormMutation, ModelFormMutation +from ..mutation import DjangoFormMutation, DjangoModelFormMutation class MyForm(forms.Form): @@ -19,14 +19,14 @@ class PetForm(forms.ModelForm): def test_needs_form_class(): with raises(Exception) as exc: - class MyMutation(FormMutation): + class MyMutation(DjangoFormMutation): pass assert exc.value.args[0] == 'form_class is required for FormMutation' def test_has_output_fields(): - class MyMutation(FormMutation): + class MyMutation(DjangoFormMutation): class Meta: form_class = MyForm @@ -34,7 +34,7 @@ def test_has_output_fields(): def test_has_input_fields(): - class MyMutation(FormMutation): + class MyMutation(DjangoFormMutation): class Meta: form_class = MyForm @@ -44,7 +44,7 @@ def test_has_input_fields(): class ModelFormMutationTests(TestCase): def test_default_meta_fields(self): - class PetMutation(ModelFormMutation): + class PetMutation(DjangoModelFormMutation): class Meta: form_class = PetForm @@ -53,7 +53,7 @@ class ModelFormMutationTests(TestCase): self.assertIn('pet', PetMutation._meta.fields) def test_custom_return_field_name(self): - class PetMutation(ModelFormMutation): + class PetMutation(DjangoModelFormMutation): class Meta: form_class = PetForm model = Film @@ -64,19 +64,33 @@ class ModelFormMutationTests(TestCase): self.assertIn('animal', PetMutation._meta.fields) def test_model_form_mutation_mutate(self): - class PetMutation(ModelFormMutation): + class PetMutation(DjangoModelFormMutation): class Meta: form_class = PetForm - result = PetMutation.mutate_and_get_payload(None, None, name='Fluffy') + pet = Pet.objects.create(name='Axel') + + result = PetMutation.mutate_and_get_payload(None, None, id=pet.pk, name='Mia') + + self.assertEqual(Pet.objects.count(), 1) + pet.refresh_from_db() + self.assertEqual(pet.name, 'Mia') + self.assertEqual(result.errors, []) + + def test_model_form_mutation_updates_existing_(self): + class PetMutation(DjangoModelFormMutation): + class Meta: + form_class = PetForm + + result = PetMutation.mutate_and_get_payload(None, None, name='Mia') self.assertEqual(Pet.objects.count(), 1) pet = Pet.objects.get() - self.assertEqual(pet.name, 'Fluffy') + self.assertEqual(pet.name, 'Mia') self.assertEqual(result.errors, []) def test_model_form_mutation_mutate_invalid_form(self): - class PetMutation(ModelFormMutation): + class PetMutation(DjangoModelFormMutation): class Meta: form_class = PetForm