diff --git a/graphene_django/converter.py b/graphene_django/converter.py index b1a8837..47634f3 100644 --- a/graphene_django/converter.py +++ b/graphene_django/converter.py @@ -10,7 +10,7 @@ from graphene.utils.str_converters import to_camel_case, to_const from graphql import assert_valid_name from .compat import ArrayField, HStoreField, JSONField, RangeField -from .fields import get_connection_field, DjangoListField +from .fields import DjangoListField, DjangoConnectionField from .utils import get_related_model, import_single_dispatch singledispatch = import_single_dispatch() @@ -148,8 +148,16 @@ def convert_field_to_list_or_connection(field, registry=None): if not _type: return - if is_node(_type): - return get_connection_field(_type) + # If there is a connection, we should transform the field + # into a DjangoConnectionField + if _type._meta.connection: + # Use a DjangoFilterConnectionField if there are + # defined filter_fields in the DjangoObjectType Meta + if _type._meta.filter_fields: + from .filter.fields import DjangoFilterConnectionField + return DjangoFilterConnectionField(_type) + + return DjangoConnectionField(_type) return DjangoListField(_type) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index c6dcd26..65697a5 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -43,6 +43,13 @@ class DjangoConnectionField(ConnectionField): ) super(DjangoConnectionField, self).__init__(*args, **kwargs) + @property + def type(self): + from .types import DjangoObjectType + _type = super(ConnectionField, self).type + assert issubclass(_type, DjangoObjectType), "DjangoConnectionField only accepts DjangoObjectType types" + return _type._meta.connection + @property def node_type(self): return self.type._meta.node @@ -128,10 +135,3 @@ class DjangoConnectionField(ConnectionField): self.max_limit, self.enforce_first_or_last ) - - -def get_connection_field(*args, **kwargs): - if DJANGO_FILTER_INSTALLED: - from .filter.fields import DjangoFilterConnectionField - return DjangoFilterConnectionField(*args, **kwargs) - return DjangoConnectionField(*args, **kwargs) diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 1b24ff2..114cf37 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -356,7 +356,7 @@ def test_recursive_filter_connection(): class ReporterFilterNode(DjangoObjectType): child_reporters = DjangoFilterConnectionField(lambda: ReporterFilterNode) - def resolve_child_reporters(self, args, context, info): + def resolve_child_reporters(self, **args): return [] class Meta: @@ -399,7 +399,7 @@ def test_should_query_filter_node_limit(): filterset_class=ReporterFilter ) - def resolve_all_reporters(self, args, context, info): + def resolve_all_reporters(self, **args): return Reporter.objects.order_by('a_choice') Reporter.objects.create( @@ -499,7 +499,7 @@ def test_should_query_filter_node_double_limit_raises(): filterset_class=ReporterFilter ) - def resolve_all_reporters(self, args, context, info): + def resolve_all_reporters(self, **args): return Reporter.objects.order_by('a_choice')[:2] Reporter.objects.create( diff --git a/graphene_django/registry.py b/graphene_django/registry.py index 21fed12..b45c0c5 100644 --- a/graphene_django/registry.py +++ b/graphene_django/registry.py @@ -1,8 +1,10 @@ + class Registry(object): def __init__(self): self._registry = {} self._registry_models = {} + self._connection_types = {} def register(self, cls): from .types import DjangoObjectType diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index e5b3be0..070401b 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -3,16 +3,14 @@ from functools import partial import six import graphene -from graphene.types import Argument, Field -from graphene.types.mutation import Mutation, MutationMeta +from graphene import relay +from graphene.types import Argument, Field, InputField +from graphene.types.mutation import Mutation, MutationOptions from graphene.types.objecttype import ( - ObjectTypeMeta, - merge, yank_fields_from_attrs ) from graphene.types.options import Options from graphene.types.utils import get_field_as -from graphene.utils.is_base_type import is_base_type from .serializer_converter import ( convert_serializer_to_input_type, @@ -21,91 +19,53 @@ from .serializer_converter import ( from .types import ErrorType -class SerializerMutationOptions(Options): - def __init__(self, *args, **kwargs): - super().__init__(*args, serializer_class=None, **kwargs) +class SerializerMutationOptions(MutationOptions): + serializer_class = None -class SerializerMutationMeta(MutationMeta): - def __new__(cls, name, bases, attrs): - if not is_base_type(bases, SerializerMutationMeta): - return type.__new__(cls, name, bases, attrs) - - options = Options( - attrs.pop('Meta', None), - name=name, - description=attrs.pop('__doc__', None), - serializer_class=None, - local_fields=None, - only_fields=(), - exclude_fields=(), - interfaces=(), - registry=None +def fields_for_serializer(serializer, only_fields, exclude_fields): + fields = OrderedDict() + for name, field in serializer.fields.items(): + is_not_in_only = only_fields and name not in only_fields + is_excluded = ( + name in exclude_fields # or + # name in already_created_fields ) - if not options.serializer_class: - raise Exception('Missing serializer_class') + if is_not_in_only or is_excluded: + continue - cls = ObjectTypeMeta.__new__( - cls, name, bases, dict(attrs, _meta=options) - ) + fields[name] = convert_serializer_field(field, is_input=False) + return fields - serializer_fields = cls.fields_for_serializer(options) - options.serializer_fields = yank_fields_from_attrs( + +class SerializerMutation(relay.ClientIDMutation): + errors = graphene.List( + ErrorType, + description='May contain more than one error for same field.' + ) + + @classmethod + def __init_subclass_with_meta__(cls, serializer_class, + only_fields=(), exclude_fields=(), **options): + + if not serializer_class: + raise Exception('serializer_class is required for the SerializerMutation') + + serializer = serializer_class() + serializer_fields = fields_for_serializer(serializer, only_fields, exclude_fields) + + _meta = SerializerMutationOptions(cls) + _meta.fields = yank_fields_from_attrs( serializer_fields, _as=Field, ) - options.fields = merge( - options.interface_fields, options.serializer_fields, - options.base_fields, options.local_fields, - {'errors': get_field_as(cls.errors, Field)} + _meta.input_fields = yank_fields_from_attrs( + serializer_fields, + _as=InputField, ) - cls.Input = convert_serializer_to_input_type(options.serializer_class) - - cls.Field = partial( - Field, - cls, - resolver=cls.mutate, - input=Argument(cls.Input, required=True) - ) - - return cls - - @staticmethod - def fields_for_serializer(options): - serializer = options.serializer_class() - - only_fields = options.only_fields - - already_created_fields = { - name - for name, _ in options.local_fields.items() - } - - fields = OrderedDict() - for name, field in serializer.fields.items(): - is_not_in_only = only_fields and name not in only_fields - is_excluded = ( - name in options.exclude_fields or - name in already_created_fields - ) - - if is_not_in_only or is_excluded: - continue - - fields[name] = convert_serializer_field(field, is_input=False) - return fields - - -class SerializerMutation(six.with_metaclass(SerializerMutationMeta, Mutation)): - errors = graphene.List( - ErrorType, - description='May contain more than one error for ' - 'same field.' - ) - @classmethod def mutate(cls, instance, args, request, info): input = args.get('input') diff --git a/graphene_django/tests/schema.py b/graphene_django/tests/schema.py index 6f6f158..6aa8f28 100644 --- a/graphene_django/tests/schema.py +++ b/graphene_django/tests/schema.py @@ -22,7 +22,7 @@ class Human(DjangoObjectType): model = Article interfaces = (relay.Node, ) - def resolve_raises(self, *args): + def resolve_raises(self): raise Exception("This field should raise exception") def get_node(self, id): @@ -32,7 +32,7 @@ class Human(DjangoObjectType): class Query(graphene.ObjectType): human = graphene.Field(Human) - def resolve_human(self, args, context, info): + def resolve_human(self): return Human() diff --git a/graphene_django/tests/schema_view.py b/graphene_django/tests/schema_view.py index 429a9f8..8380407 100644 --- a/graphene_django/tests/schema_view.py +++ b/graphene_django/tests/schema_view.py @@ -1,5 +1,5 @@ import graphene -from graphene import ObjectType, Schema +from graphene import ObjectType, Schema, annotate, Context class QueryRoot(ObjectType): @@ -8,21 +8,21 @@ class QueryRoot(ObjectType): request = graphene.String(required=True) test = graphene.String(who=graphene.String()) - def resolve_thrower(self, args, context, info): + def resolve_thrower(self): raise Exception("Throws!") - def resolve_request(self, args, context, info): - request = context + @annotate(request=Context) + def resolve_request(self, request): return request.GET.get('q') - def resolve_test(self, args, context, info): - return 'Hello %s' % (args.get('who') or 'World') + def resolve_test(self, who=None): + return 'Hello %s' % (who or 'World') class MutationRoot(ObjectType): write_test = graphene.Field(QueryRoot) - def resolve_write_test(self, args, context, info): + def resolve_write_test(self): return QueryRoot() diff --git a/graphene_django/tests/test_converter.py b/graphene_django/tests/test_converter.py index 35caa15..e424177 100644 --- a/graphene_django/tests/test_converter.py +++ b/graphene_django/tests/test_converter.py @@ -224,7 +224,7 @@ def test_should_manytomany_convert_connectionorlist_connection(): assert isinstance(graphene_field, graphene.Dynamic) dynamic_field = graphene_field.get_type() assert isinstance(dynamic_field, ConnectionField) - assert dynamic_field.type == A.Connection + assert dynamic_field.type == A._meta.connection def test_should_manytoone_convert_connectionorlist(): diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 3ecd8ea..7b2b46b 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -46,7 +46,7 @@ def test_should_query_simplelazy_objects(): class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - def resolve_reporter(self, args, context, info): + def resolve_reporter(self): return SimpleLazyObject(lambda: Reporter(id=1)) schema = graphene.Schema(query=Query) @@ -75,7 +75,7 @@ def test_should_query_well(): class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self): return Reporter(first_name='ABA', last_name='X') query = ''' @@ -119,7 +119,7 @@ def test_should_query_postgres_fields(): class Query(graphene.ObjectType): event = graphene.Field(EventType) - def resolve_event(self, *args, **kwargs): + def resolve_event(self): return Event( ages=(0, 10), data={'angry_babies': True}, @@ -165,7 +165,7 @@ def test_should_node(): def get_node(cls, id, context, info): return Reporter(id=2, first_name='Cookie Monster') - def resolve_articles(self, *args, **kwargs): + def resolve_articles(self, **args): return [Article(headline='Hi!')] class ArticleNode(DjangoObjectType): @@ -183,7 +183,7 @@ def test_should_node(): reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self): return Reporter(id=1, first_name='ABA', last_name='X') query = ''' @@ -250,7 +250,7 @@ def test_should_query_connectionfields(): class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) - def resolve_all_reporters(self, args, context, info): + def resolve_all_reporters(self, **args): return [Reporter(id=1)] schema = graphene.Schema(query=Query) @@ -308,10 +308,10 @@ def test_should_keep_annotations(): all_reporters = DjangoConnectionField(ReporterType) all_articles = DjangoConnectionField(ArticleType) - def resolve_all_reporters(self, args, context, info): + def resolve_all_reporters(self, **args): return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c') - def resolve_all_articles(self, args, context, info): + def resolve_all_articles(self, **args): return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg') schema = graphene.Schema(query=Query) @@ -618,7 +618,7 @@ def test_should_query_promise_connectionfields(): class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) - def resolve_all_reporters(self, *args, **kwargs): + def resolve_all_reporters(self, **args): return Promise.resolve([Reporter(id=1)]) schema = graphene.Schema(query=Query) @@ -673,10 +673,11 @@ def test_should_query_dataloader_fields(): class Meta: model = Reporter interfaces = (Node, ) + use_connection = True articles = DjangoConnectionField(ArticleType) - def resolve_articles(self, *args, **kwargs): + def resolve_articles(self, **args): return article_loader.load(self.id) class Query(graphene.ObjectType): diff --git a/setup.py b/setup.py index bd24009..cc84d4b 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import find_packages, setup rest_framework_require = [ - 'djangorestframework==3.6.3', + 'djangorestframework>=3.6.3', ] @@ -17,7 +17,7 @@ tests_require = [ setup( name='graphene-django', - version='1.3', + version='2.0.dev', description='Graphene Django integration', long_description=open('README.rst').read(), @@ -48,11 +48,11 @@ setup( install_requires=[ 'six>=1.10.0', - 'graphene>=1.4', + 'graphene>=2.0.dev', 'Django>=1.8.0', 'iso8601', 'singledispatch>=3.4.0.3', - 'promise>=2.0', + 'promise>=2.1.dev', ], setup_requires=[ 'pytest-runner',