diff --git a/docs/rest-framework.rst b/docs/rest-framework.rst index 5e5dd70..028b42a 100644 --- a/docs/rest-framework.rst +++ b/docs/rest-framework.rst @@ -19,3 +19,50 @@ You can create a Mutation based on a serializer by using the class Meta: serializer_class = MySerializer +Create/Update Operations +--------------------- + +By default ModelSerializers accept create and update operations. To +customize this use the `model_operations` attribute. The update +operation looks up models by the primary key by default. You can +customize the look up with the lookup attribute. + +Other default attributes: + +`partial = False`: Accept updates without all the input fields. + +.. code:: python + + from graphene_django.rest_framework.mutation import SerializerMutation + + class AwesomeModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ['create', 'update'] + lookup_field = 'id' + +Overriding Update Queries +------------------------- + +Use the method `get_serializer_kwargs` to override how +updates are applied. + +.. code:: python + + from graphene_django.rest_framework.mutation import SerializerMutation + + class AwesomeModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + + @classmethod + def get_serializer_kwargs(cls, root, info, **input): + if 'id' in input: + instance = Post.objects.filter(id=input['id'], owner=info.context.user).first() + if instance: + return {'instance': instance, 'data': input, 'partial': True} + + else: + raise http.Http404 + + return {'data': input, 'partial': True} diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index a776eab..a694553 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -1,5 +1,7 @@ from collections import OrderedDict +from django.shortcuts import get_object_or_404 + import graphene from graphene.types import Field, InputField from graphene.types.mutation import MutationOptions @@ -15,6 +17,9 @@ from .types import ErrorType class SerializerMutationOptions(MutationOptions): + lookup_field = None + model_class = None + model_operations = ['create', 'update'] serializer_class = None @@ -44,18 +49,34 @@ class SerializerMutation(ClientIDMutation): ) @classmethod - def __init_subclass_with_meta__(cls, serializer_class=None, + def __init_subclass_with_meta__(cls, lookup_field=None, + serializer_class=None, model_class=None, + model_operations=['create', 'update'], only_fields=(), exclude_fields=(), **options): if not serializer_class: raise Exception('serializer_class is required for the SerializerMutation') + if 'update' not in model_operations and 'create' not in model_operations: + raise Exception('model_operations must contain "create" and/or "update"') + serializer = serializer_class() + if model_class is None: + serializer_meta = getattr(serializer_class, 'Meta', None) + if serializer_meta: + model_class = getattr(serializer_meta, 'model', None) + + if lookup_field is None and model_class: + lookup_field = model_class._meta.pk.name + input_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=True) output_fields = fields_for_serializer(serializer, only_fields, exclude_fields, is_input=False) _meta = SerializerMutationOptions(cls) + _meta.lookup_field = lookup_field + _meta.model_operations = model_operations _meta.serializer_class = serializer_class + _meta.model_class = model_class _meta.fields = yank_fields_from_attrs( output_fields, _as=Field, @@ -67,9 +88,35 @@ class SerializerMutation(ClientIDMutation): ) super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) + @classmethod + def get_serializer_kwargs(cls, root, info, **input): + lookup_field = cls._meta.lookup_field + model_class = cls._meta.model_class + + if model_class: + if 'update' in cls._meta.model_operations and lookup_field in input: + instance = get_object_or_404(model_class, **{ + lookup_field: input[lookup_field]}) + elif 'create' in cls._meta.model_operations: + instance = None + else: + raise Exception( + 'Invalid update operation. Input parameter "{}" required.'.format( + lookup_field + )) + + return { + 'instance': instance, + 'data': input, + 'context': {'request': info.context} + } + + return {'data': input, 'context': {'request': info.context}} + @classmethod def mutate_and_get_payload(cls, root, info, **input): - serializer = cls._meta.serializer_class(data=input) + kwargs = cls.get_serializer_kwargs(root, info, **input) + serializer = cls._meta.serializer_class(**kwargs) if serializer.is_valid(): return cls.perform_mutate(serializer, info) diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py index 491192a..35acab7 100644 --- a/graphene_django/rest_framework/tests/test_mutation.py +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -1,6 +1,6 @@ import datetime -from graphene import Field +from graphene import Field, ResolveInfo from graphene.types.inputobjecttype import InputObjectType from py.test import raises from py.test import mark @@ -10,12 +10,29 @@ from ...types import DjangoObjectType from ..models import MyFakeModel from ..mutation import SerializerMutation +def mock_info(): + return ResolveInfo( + None, + None, + None, + None, + schema=None, + fragments=None, + root_value=None, + operation=None, + variable_values=None, + context=None + ) + class MyModelSerializer(serializers.ModelSerializer): class Meta: model = MyFakeModel fields = '__all__' +class MyModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer class MySerializer(serializers.Serializer): text = serializers.CharField() @@ -92,7 +109,7 @@ def test_mutate_and_get_payload_success(): class Meta: serializer_class = MySerializer - result = MyMutation.mutate_and_get_payload(None, None, **{ + result = MyMutation.mutate_and_get_payload(None, mock_info(), **{ 'text': 'value', 'model': { 'cool_name': 'other_value' @@ -102,18 +119,38 @@ def test_mutate_and_get_payload_success(): @mark.django_db -def test_model_mutate_and_get_payload_success(): - class MyMutation(SerializerMutation): - class Meta: - serializer_class = MyModelSerializer - - result = MyMutation.mutate_and_get_payload(None, None, **{ +def test_model_add_mutate_and_get_payload_success(): + result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{ 'cool_name': 'Narf', }) assert result.errors is None assert result.cool_name == 'Narf' assert isinstance(result.created, datetime.datetime) +@mark.django_db +def test_model_update_mutate_and_get_payload_success(): + instance = MyFakeModel.objects.create(cool_name="Narf") + result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{ + 'id': instance.id, + 'cool_name': 'New Narf', + }) + assert result.errors is None + assert result.cool_name == 'New Narf' + +@mark.django_db +def test_model_invalid_update_mutate_and_get_payload_success(): + class InvalidModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ['update'] + + with raises(Exception) as exc: + result = InvalidModelMutation.mutate_and_get_payload(None, mock_info(), **{ + 'cool_name': 'Narf', + }) + + assert '"id" required' in str(exc.value) + def test_mutate_and_get_payload_error(): class MyMutation(SerializerMutation): @@ -121,15 +158,19 @@ def test_mutate_and_get_payload_error(): serializer_class = MySerializer # missing required fields - result = MyMutation.mutate_and_get_payload(None, None, **{}) + result = MyMutation.mutate_and_get_payload(None, mock_info(), **{}) assert len(result.errors) > 0 def test_model_mutate_and_get_payload_error(): - - class MyMutation(SerializerMutation): - class Meta: - serializer_class = MyModelSerializer - # missing required fields - result = MyMutation.mutate_and_get_payload(None, None, **{}) + result = MyModelMutation.mutate_and_get_payload(None, mock_info(), **{}) assert len(result.errors) > 0 + +def test_invalid_serializer_operations(): + with raises(Exception) as exc: + class MyModelMutation(SerializerMutation): + class Meta: + serializer_class = MyModelSerializer + model_operations = ['Add'] + + assert 'model_operations' in str(exc.value)