diff --git a/docs/index.rst b/docs/index.rst index ccc6bd4..256da68 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -11,4 +11,5 @@ Contents: filtering authorization debug + rest-framework introspection diff --git a/docs/rest-framework.rst b/docs/rest-framework.rst new file mode 100644 index 0000000..5e5dd70 --- /dev/null +++ b/docs/rest-framework.rst @@ -0,0 +1,21 @@ +Integration with Django Rest Framework +====================================== + +You can re-use your Django Rest Framework serializer with +graphene django. + + +Mutation +-------- + +You can create a Mutation based on a serializer by using the +`SerializerMutation` base class: + +.. code:: python + + from graphene_django.rest_framework.mutation import SerializerMutation + + class MyAwesomeMutation(SerializerMutation): + class Meta: + serializer_class = MySerializer + diff --git a/graphene_django/rest_framework/__init__.py b/graphene_django/rest_framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py new file mode 100644 index 0000000..e5b3be0 --- /dev/null +++ b/graphene_django/rest_framework/mutation.py @@ -0,0 +1,129 @@ +from collections import OrderedDict +from functools import partial + +import six +import graphene +from graphene.types import Argument, Field +from graphene.types.mutation import Mutation, MutationMeta +from graphene.types.objecttype import ( + ObjectTypeMeta, + merge, + yank_fields_from_attrs +) +from graphene.types.options import Options +from graphene.types.utils import get_field_as +from graphene.utils.is_base_type import is_base_type + +from .serializer_converter import ( + convert_serializer_to_input_type, + convert_serializer_field +) +from .types import ErrorType + + +class SerializerMutationOptions(Options): + def __init__(self, *args, **kwargs): + super().__init__(*args, serializer_class=None, **kwargs) + + +class SerializerMutationMeta(MutationMeta): + def __new__(cls, name, bases, attrs): + if not is_base_type(bases, SerializerMutationMeta): + return type.__new__(cls, name, bases, attrs) + + options = Options( + attrs.pop('Meta', None), + name=name, + description=attrs.pop('__doc__', None), + serializer_class=None, + local_fields=None, + only_fields=(), + exclude_fields=(), + interfaces=(), + registry=None + ) + + if not options.serializer_class: + raise Exception('Missing serializer_class') + + cls = ObjectTypeMeta.__new__( + cls, name, bases, dict(attrs, _meta=options) + ) + + serializer_fields = cls.fields_for_serializer(options) + options.serializer_fields = yank_fields_from_attrs( + serializer_fields, + _as=Field, + ) + + options.fields = merge( + options.interface_fields, options.serializer_fields, + options.base_fields, options.local_fields, + {'errors': get_field_as(cls.errors, Field)} + ) + + cls.Input = convert_serializer_to_input_type(options.serializer_class) + + cls.Field = partial( + Field, + cls, + resolver=cls.mutate, + input=Argument(cls.Input, required=True) + ) + + return cls + + @staticmethod + def fields_for_serializer(options): + serializer = options.serializer_class() + + only_fields = options.only_fields + + already_created_fields = { + name + for name, _ in options.local_fields.items() + } + + fields = OrderedDict() + for name, field in serializer.fields.items(): + is_not_in_only = only_fields and name not in only_fields + is_excluded = ( + name in options.exclude_fields or + name in already_created_fields + ) + + if is_not_in_only or is_excluded: + continue + + fields[name] = convert_serializer_field(field, is_input=False) + return fields + + +class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)): + errors = graphene.List( + ErrorType, + description='May contain more than one error for ' + 'same field.' + ) + + @classmethod + def mutate(cls, instance, args, request, info): + input = args.get('input') + + serializer = cls._meta.serializer_class(data=dict(input)) + + if serializer.is_valid(): + return cls.perform_mutate(serializer, info) + else: + errors = [ + ErrorType(field=key, messages=value) + for key, value in serializer.errors.items() + ] + + return cls(errors=errors) + + @classmethod + def perform_mutate(cls, serializer, info): + obj = serializer.save() + + return cls(errors=[], **obj) diff --git a/graphene_django/rest_framework/serializer_converter.py b/graphene_django/rest_framework/serializer_converter.py new file mode 100644 index 0000000..8b04d46 --- /dev/null +++ b/graphene_django/rest_framework/serializer_converter.py @@ -0,0 +1,124 @@ +from django.core.exceptions import ImproperlyConfigured +from rest_framework import serializers + +import graphene + +from ..registry import get_global_registry +from ..utils import import_single_dispatch +from .types import DictType + +singledispatch = import_single_dispatch() + + +def convert_serializer_to_input_type(serializer_class): + serializer = serializer_class() + + items = { + name: convert_serializer_field(field) + for name, field in serializer.fields.items() + } + + return type( + '{}Input'.format(serializer.__class__.__name__), + (graphene.InputObjectType, ), + items + ) + + +@singledispatch +def get_graphene_type_from_serializer_field(field): + raise ImproperlyConfigured( + "Don't know how to convert the serializer field %s (%s) " + "to Graphene type" % (field, field.__class__) + ) + + +def convert_serializer_field(field, is_input=True): + """ + Converts a django rest frameworks field to a graphql field + and marks the field as required if we are creating an input type + and the field itself is required + """ + + graphql_type = get_graphene_type_from_serializer_field(field) + + args = [] + kwargs = { + 'description': field.help_text, + 'required': is_input and field.required, + } + + # if it is a tuple or a list it means that we are returning + # the graphql type and the child type + if isinstance(graphql_type, (list, tuple)): + kwargs['of_type'] = graphql_type[1] + graphql_type = graphql_type[0] + + if isinstance(field, serializers.ModelSerializer): + if is_input: + graphql_type = convert_serializer_to_input_type(field.__class__) + else: + global_registry = get_global_registry() + field_model = field.Meta.model + args = [global_registry.get_type_for_model(field_model)] + + return graphql_type(*args, **kwargs) + + +@get_graphene_type_from_serializer_field.register(serializers.Field) +def convert_serializer_field_to_string(field): + return graphene.String + + +@get_graphene_type_from_serializer_field.register(serializers.ModelSerializer) +def convert_serializer_to_field(field): + return graphene.Field + + +@get_graphene_type_from_serializer_field.register(serializers.IntegerField) +def convert_serializer_field_to_int(field): + return graphene.Int + + +@get_graphene_type_from_serializer_field.register(serializers.BooleanField) +def convert_serializer_field_to_bool(field): + return graphene.Boolean + + +@get_graphene_type_from_serializer_field.register(serializers.FloatField) +@get_graphene_type_from_serializer_field.register(serializers.DecimalField) +def convert_serializer_field_to_float(field): + return graphene.Float + + +@get_graphene_type_from_serializer_field.register(serializers.DateTimeField) +@get_graphene_type_from_serializer_field.register(serializers.DateField) +def convert_serializer_field_to_date_time(field): + return graphene.types.datetime.DateTime + + +@get_graphene_type_from_serializer_field.register(serializers.TimeField) +def convert_serializer_field_to_time(field): + return graphene.types.datetime.Time + + +@get_graphene_type_from_serializer_field.register(serializers.ListField) +def convert_serializer_field_to_list(field, is_input=True): + child_type = get_graphene_type_from_serializer_field(field.child) + + return (graphene.List, child_type) + + +@get_graphene_type_from_serializer_field.register(serializers.DictField) +def convert_serializer_field_to_dict(field): + return DictType + + +@get_graphene_type_from_serializer_field.register(serializers.JSONField) +def convert_serializer_field_to_jsonstring(field): + return graphene.types.json.JSONString + + +@get_graphene_type_from_serializer_field.register(serializers.MultipleChoiceField) +def convert_serializer_field_to_list_of_string(field): + return (graphene.List, graphene.String) diff --git a/graphene_django/rest_framework/tests/__init__.py b/graphene_django/rest_framework/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/graphene_django/rest_framework/tests/test_field_converter.py b/graphene_django/rest_framework/tests/test_field_converter.py new file mode 100644 index 0000000..623cf58 --- /dev/null +++ b/graphene_django/rest_framework/tests/test_field_converter.py @@ -0,0 +1,162 @@ +import copy +from rest_framework import serializers +from py.test import raises + +import graphene + +from ..serializer_converter import convert_serializer_field +from ..types import DictType + + +def _get_type(rest_framework_field, is_input=True, **kwargs): + # prevents the following error: + # AssertionError: The `source` argument is not meaningful when applied to a `child=` field. + # Remove `source=` from the field declaration. + # since we are reusing the same child in when testing the required attribute + + if 'child' in kwargs: + kwargs['child'] = copy.deepcopy(kwargs['child']) + + field = rest_framework_field(**kwargs) + + return convert_serializer_field(field, is_input=is_input) + + +def assert_conversion(rest_framework_field, graphene_field, **kwargs): + graphene_type = _get_type(rest_framework_field, help_text='Custom Help Text', **kwargs) + assert isinstance(graphene_type, graphene_field) + + graphene_type_required = _get_type( + rest_framework_field, help_text='Custom Help Text', required=True, **kwargs + ) + assert isinstance(graphene_type_required, graphene_field) + + return graphene_type + + +def test_should_unknown_rest_framework_field_raise_exception(): + with raises(Exception) as excinfo: + convert_serializer_field(None) + assert 'Don\'t know how to convert the serializer field' in str(excinfo.value) + + +def test_should_char_convert_string(): + assert_conversion(serializers.CharField, graphene.String) + + +def test_should_email_convert_string(): + assert_conversion(serializers.EmailField, graphene.String) + + +def test_should_slug_convert_string(): + assert_conversion(serializers.SlugField, graphene.String) + + +def test_should_url_convert_string(): + assert_conversion(serializers.URLField, graphene.String) + + +def test_should_choice_convert_string(): + assert_conversion(serializers.ChoiceField, graphene.String, choices=[]) + + +def test_should_base_field_convert_string(): + assert_conversion(serializers.Field, graphene.String) + + +def test_should_regex_convert_string(): + assert_conversion(serializers.RegexField, graphene.String, regex='[0-9]+') + + +def test_should_uuid_convert_string(): + if hasattr(serializers, 'UUIDField'): + assert_conversion(serializers.UUIDField, graphene.String) + + +def test_should_model_convert_field(): + + class MyModelSerializer(serializers.ModelSerializer): + class Meta: + model = None + fields = '__all__' + + assert_conversion(MyModelSerializer, graphene.Field, is_input=False) + + +def test_should_date_time_convert_datetime(): + assert_conversion(serializers.DateTimeField, graphene.types.datetime.DateTime) + + +def test_should_date_convert_datetime(): + assert_conversion(serializers.DateField, graphene.types.datetime.DateTime) + + +def test_should_time_convert_time(): + assert_conversion(serializers.TimeField, graphene.types.datetime.Time) + + +def test_should_integer_convert_int(): + assert_conversion(serializers.IntegerField, graphene.Int) + + +def test_should_boolean_convert_boolean(): + assert_conversion(serializers.BooleanField, graphene.Boolean) + + +def test_should_float_convert_float(): + assert_conversion(serializers.FloatField, graphene.Float) + + +def test_should_decimal_convert_float(): + assert_conversion(serializers.DecimalField, graphene.Float, max_digits=4, decimal_places=2) + + +def test_should_list_convert_to_list(): + class StringListField(serializers.ListField): + child = serializers.CharField() + + field_a = assert_conversion( + serializers.ListField, + graphene.List, + child=serializers.IntegerField(min_value=0, max_value=100) + ) + + assert field_a.of_type == graphene.Int + + field_b = assert_conversion(StringListField, graphene.List) + + assert field_b.of_type == graphene.String + + +def test_should_dict_convert_dict(): + assert_conversion(serializers.DictField, DictType) + + +def test_should_duration_convert_string(): + assert_conversion(serializers.DurationField, graphene.String) + + +def test_should_file_convert_string(): + assert_conversion(serializers.FileField, graphene.String) + + +def test_should_filepath_convert_string(): + assert_conversion(serializers.FilePathField, graphene.String, path='/') + + +def test_should_ip_convert_string(): + assert_conversion(serializers.IPAddressField, graphene.String) + + +def test_should_image_convert_string(): + assert_conversion(serializers.ImageField, graphene.String) + + +def test_should_json_convert_jsonstring(): + assert_conversion(serializers.JSONField, graphene.types.json.JSONString) + + +def test_should_multiplechoicefield_convert_to_list_of_string(): + field = assert_conversion(serializers.MultipleChoiceField, graphene.List, choices=[1,2,3]) + + assert field.of_type == graphene.String diff --git a/graphene_django/rest_framework/tests/test_mutation.py b/graphene_django/rest_framework/tests/test_mutation.py new file mode 100644 index 0000000..5143f76 --- /dev/null +++ b/graphene_django/rest_framework/tests/test_mutation.py @@ -0,0 +1,70 @@ +from django.db import models +from graphene import Field +from graphene.types.inputobjecttype import InputObjectType +from py.test import raises +from rest_framework import serializers + +from ...types import DjangoObjectType +from ..mutation import SerializerMutation + + +class MyFakeModel(models.Model): + cool_name = models.CharField(max_length=50) + + +class MyModelSerializer(serializers.ModelSerializer): + class Meta: + model = MyFakeModel + fields = '__all__' + + +class MySerializer(serializers.Serializer): + text = serializers.CharField() + model = MyModelSerializer() + + +def test_needs_serializer_class(): + with raises(Exception) as exc: + class MyMutation(SerializerMutation): + pass + + assert exc.value.args[0] == 'Missing serializer_class' + + +def test_has_fields(): + class MyMutation(SerializerMutation): + class Meta: + serializer_class = MySerializer + + assert 'text' in MyMutation._meta.fields + assert 'model' in MyMutation._meta.fields + assert 'errors' in MyMutation._meta.fields + + +def test_has_input_fields(): + class MyMutation(SerializerMutation): + class Meta: + serializer_class = MySerializer + + assert 'text' in MyMutation.Input._meta.fields + assert 'model' in MyMutation.Input._meta.fields + + +def test_nested_model(): + + class MyFakeModelGrapheneType(DjangoObjectType): + class Meta: + model = MyFakeModel + + class MyMutation(SerializerMutation): + class Meta: + serializer_class = MySerializer + + model_field = MyMutation._meta.fields['model'] + assert isinstance(model_field, Field) + assert model_field.type == MyFakeModelGrapheneType + + model_input = MyMutation.Input._meta.fields['model'] + model_input_type = model_input._type.of_type + assert issubclass(model_input_type, InputObjectType) + assert 'cool_name' in model_input_type._meta.fields diff --git a/graphene_django/rest_framework/types.py b/graphene_django/rest_framework/types.py new file mode 100644 index 0000000..956dc43 --- /dev/null +++ b/graphene_django/rest_framework/types.py @@ -0,0 +1,12 @@ +import graphene +from graphene.types.unmountedtype import UnmountedType + + +class ErrorType(graphene.ObjectType): + field = graphene.String() + messages = graphene.List(graphene.String) + + +class DictType(UnmountedType): + key = graphene.String() + value = graphene.String() diff --git a/graphene_django/utils.py b/graphene_django/utils.py index 468dc4c..6fc5599 100644 --- a/graphene_django/utils.py +++ b/graphene_django/utils.py @@ -11,8 +11,11 @@ class LazyList(object): pass -import django_filters # noqa -DJANGO_FILTER_INSTALLED = True +try: + import django_filters # noqa + DJANGO_FILTER_INSTALLED = True +except ImportError: + DJANGO_FILTER_INSTALLED = False def get_reverse_fields(model, local_field_names): diff --git a/setup.py b/setup.py index 8a503c2..bd24009 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,10 @@ from setuptools import find_packages, setup +rest_framework_require = [ + 'djangorestframework==3.6.3', +] + + tests_require = [ 'pytest>=2.7.2', 'pytest-cov', @@ -8,7 +13,7 @@ tests_require = [ 'pytz', 'django-filter', 'pytest-django==2.9.1', -] +] + rest_framework_require setup( name='graphene-django', @@ -53,8 +58,10 @@ setup( 'pytest-runner', ], tests_require=tests_require, + rest_framework_require=rest_framework_require, extras_require={ 'test': tests_require, + 'rest_framework': rest_framework_require, }, include_package_data=True, zip_safe=False,