From 3d58148f0311650330ce087fb9662f7cf951e215 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Fri, 28 Jul 2017 09:43:27 -0700 Subject: [PATCH] Code adapted to new resolver API --- graphene_django/debug/middleware.py | 5 +++-- graphene_django/debug/tests/test_query.py | 8 +++---- graphene_django/fields.py | 15 ++++++------- graphene_django/filter/fields.py | 12 +++++------ graphene_django/filter/tests/test_fields.py | 4 ++-- graphene_django/rest_framework/mutation.py | 10 ++++----- graphene_django/tests/schema.py | 8 +++---- graphene_django/tests/schema_view.py | 13 ++++++----- graphene_django/tests/test_query.py | 24 ++++++++++----------- graphene_django/tests/test_types.py | 2 +- graphene_django/types.py | 6 +++--- 11 files changed, 51 insertions(+), 56 deletions(-) diff --git a/graphene_django/debug/middleware.py b/graphene_django/debug/middleware.py index acd8524..2b11f7e 100644 --- a/graphene_django/debug/middleware.py +++ b/graphene_django/debug/middleware.py @@ -39,7 +39,8 @@ class DjangoDebugContext(object): class DjangoDebugMiddleware(object): - def resolve(self, next, root, args, context, info): + def resolve(self, next, root, info, **args): + context = info.context django_debug = getattr(context, 'django_debug', None) if not django_debug: if context is None: @@ -52,6 +53,6 @@ class DjangoDebugMiddleware(object): )) if info.schema.get_type('DjangoDebug') == info.return_type: return context.django_debug.get_debug_promise() - promise = next(root, args, context, info) + promise = next(root, info, **args) context.django_debug.add_promise(promise) return promise diff --git a/graphene_django/debug/tests/test_query.py b/graphene_django/debug/tests/test_query.py index 125f917..72747b2 100644 --- a/graphene_django/debug/tests/test_query.py +++ b/graphene_django/debug/tests/test_query.py @@ -33,7 +33,7 @@ def test_should_query_field(): reporter = graphene.Field(ReporterType) debug = graphene.Field(DjangoDebug, name='__debug') - def resolve_reporter(self, *args, **kwargs): + def resolve_reporter(self, info, **args): return Reporter.objects.first() query = ''' @@ -80,7 +80,7 @@ def test_should_query_list(): all_reporters = graphene.List(ReporterType) debug = graphene.Field(DjangoDebug, name='__debug') - def resolve_all_reporters(self, *args, **kwargs): + def resolve_all_reporters(self, info, **args): return Reporter.objects.all() query = ''' @@ -129,7 +129,7 @@ def test_should_query_connection(): all_reporters = DjangoConnectionField(ReporterType) debug = graphene.Field(DjangoDebug, name='__debug') - def resolve_all_reporters(self, *args, **kwargs): + def resolve_all_reporters(self, info, **args): return Reporter.objects.all() query = ''' @@ -185,7 +185,7 @@ def test_should_query_connectionfilter(): s = graphene.String(resolver=lambda *_: "S") debug = graphene.Field(DjangoDebug, name='__debug') - def resolve_all_reporters(self, *args, **kwargs): + def resolve_all_reporters(self, info, **args): return Reporter.objects.all() query = ''' diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 7cf90d0..aa7f124 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -4,7 +4,6 @@ from django.db.models.query import QuerySet from promise import Promise -from graphene import final_resolver from graphene.types import Field, List from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -23,11 +22,11 @@ class DjangoListField(Field): return self.type.of_type._meta.node._meta.model @staticmethod - def list_resolver(resolver, root, args, context, info): - return maybe_queryset(resolver(root, args, context, info)) + def list_resolver(resolver, root, info, **args): + return maybe_queryset(resolver(root, info, **args)) def get_resolver(self, parent_resolver): - return final_resolver(partial(self.list_resolver, parent_resolver)) + return partial(self.list_resolver, parent_resolver) class DjangoConnectionField(ConnectionField): @@ -98,7 +97,7 @@ class DjangoConnectionField(ConnectionField): @classmethod def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, args, context, info): + enforce_first_or_last, root, info, **args): first = args.get('first') last = args.get('last') @@ -120,7 +119,7 @@ class DjangoConnectionField(ConnectionField): ).format(first, info.field_name, max_limit) args['last'] = min(last, max_limit) - iterable = resolver(root, args, context, info) + iterable = resolver(root, info, **args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args) if Promise.is_thenable(iterable): @@ -129,11 +128,11 @@ class DjangoConnectionField(ConnectionField): return on_resolve(iterable) def get_resolver(self, parent_resolver): - return final_resolver(partial( + return partial( self.connection_resolver, parent_resolver, self.type, self.get_manager(), self.max_limit, self.enforce_first_or_last - )) + ) diff --git a/graphene_django/filter/fields.py b/graphene_django/filter/fields.py index c98e10a..a80d8d7 100644 --- a/graphene_django/filter/fields.py +++ b/graphene_django/filter/fields.py @@ -1,7 +1,6 @@ from collections import OrderedDict from functools import partial -from graphene import final_resolver from graphene.types.argument import to_arguments from ..fields import DjangoConnectionField from .utils import get_filtering_args_from_filterset, get_filterset_class @@ -69,7 +68,7 @@ class DjangoFilterConnectionField(DjangoConnectionField): @classmethod def connection_resolver(cls, resolver, connection, default_manager, max_limit, enforce_first_or_last, filterset_class, filtering_args, - root, args, context, info): + root, info, **args): filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} qs = filterset_class( data=filter_kwargs, @@ -83,13 +82,12 @@ class DjangoFilterConnectionField(DjangoConnectionField): max_limit, enforce_first_or_last, root, - args, - context, - info + info, + **args ) def get_resolver(self, parent_resolver): - return final_resolver(partial( + return partial( self.connection_resolver, parent_resolver, self.type, @@ -98,4 +96,4 @@ class DjangoFilterConnectionField(DjangoConnectionField): self.enforce_first_or_last, self.filterset_class, self.filtering_args - )) + ) diff --git a/graphene_django/filter/tests/test_fields.py b/graphene_django/filter/tests/test_fields.py index 3079c3a..9a0ba21 100644 --- a/graphene_django/filter/tests/test_fields.py +++ b/graphene_django/filter/tests/test_fields.py @@ -399,7 +399,7 @@ def test_should_query_filter_node_limit(): filterset_class=ReporterFilter ) - def resolve_all_reporters(self, **args): + def resolve_all_reporters(self, info, **args): return Reporter.objects.order_by('a_choice') Reporter.objects.create( @@ -499,7 +499,7 @@ def test_should_query_filter_node_double_limit_raises(): filterset_class=ReporterFilter ) - def resolve_all_reporters(self, **args): + def resolve_all_reporters(self, info, **args): return Reporter.objects.order_by('a_choice')[:2] Reporter.objects.create( diff --git a/graphene_django/rest_framework/mutation.py b/graphene_django/rest_framework/mutation.py index 83499d4..beaaa49 100644 --- a/graphene_django/rest_framework/mutation.py +++ b/graphene_django/rest_framework/mutation.py @@ -1,7 +1,6 @@ from collections import OrderedDict import graphene -from graphene import annotate, Context, ResolveInfo from graphene.types import Field, InputField from graphene.types.mutation import MutationOptions from graphene.relay.mutation import ClientIDMutation @@ -68,12 +67,11 @@ class SerializerMutation(ClientIDMutation): super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) @classmethod - @annotate(context=Context, info=ResolveInfo) - def mutate_and_get_payload(cls, root, input, context, info): - serializer = cls._meta.serializer_class(data=dict(input)) + def mutate_and_get_payload(cls, root, info, **input): + serializer = cls._meta.serializer_class(data=input) if serializer.is_valid(): - return cls.perform_mutate(serializer, context, info) + return cls.perform_mutate(serializer, info) else: errors = [ ErrorType(field=key, messages=value) @@ -83,6 +81,6 @@ class SerializerMutation(ClientIDMutation): return cls(errors=errors) @classmethod - def perform_mutate(cls, serializer, context, info): + def perform_mutate(cls, serializer, info): obj = serializer.save() return cls(**obj) diff --git a/graphene_django/tests/schema.py b/graphene_django/tests/schema.py index 6aa8f28..3134604 100644 --- a/graphene_django/tests/schema.py +++ b/graphene_django/tests/schema.py @@ -11,7 +11,7 @@ class Character(DjangoObjectType): model = Reporter interfaces = (relay.Node, ) - def get_node(self, id, context, info): + def get_node(self, info, id): pass @@ -22,17 +22,17 @@ class Human(DjangoObjectType): model = Article interfaces = (relay.Node, ) - def resolve_raises(self): + def resolve_raises(self, info): raise Exception("This field should raise exception") - def get_node(self, id): + def get_node(self, info, id): pass class Query(graphene.ObjectType): human = graphene.Field(Human) - def resolve_human(self): + def resolve_human(self, info): return Human() diff --git a/graphene_django/tests/schema_view.py b/graphene_django/tests/schema_view.py index 8380407..c750433 100644 --- a/graphene_django/tests/schema_view.py +++ b/graphene_django/tests/schema_view.py @@ -1,5 +1,5 @@ import graphene -from graphene import ObjectType, Schema, annotate, Context +from graphene import ObjectType, Schema class QueryRoot(ObjectType): @@ -8,21 +8,20 @@ class QueryRoot(ObjectType): request = graphene.String(required=True) test = graphene.String(who=graphene.String()) - def resolve_thrower(self): + def resolve_thrower(self, info): raise Exception("Throws!") - @annotate(request=Context) - def resolve_request(self, request): - return request.GET.get('q') + def resolve_request(self, info): + return info.context.GET.get('q') - def resolve_test(self, who=None): + def resolve_test(self, info, who=None): return 'Hello %s' % (who or 'World') class MutationRoot(ObjectType): write_test = graphene.Field(QueryRoot) - def resolve_write_test(self): + def resolve_write_test(self, info): return QueryRoot() diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index 7b2b46b..ec8c64c 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -46,7 +46,7 @@ def test_should_query_simplelazy_objects(): class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - def resolve_reporter(self): + def resolve_reporter(self, info): return SimpleLazyObject(lambda: Reporter(id=1)) schema = graphene.Schema(query=Query) @@ -75,7 +75,7 @@ def test_should_query_well(): class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - def resolve_reporter(self): + def resolve_reporter(self, info): return Reporter(first_name='ABA', last_name='X') query = ''' @@ -119,7 +119,7 @@ def test_should_query_postgres_fields(): class Query(graphene.ObjectType): event = graphene.Field(EventType) - def resolve_event(self): + def resolve_event(self, info): return Event( ages=(0, 10), data={'angry_babies': True}, @@ -162,10 +162,10 @@ def test_should_node(): interfaces = (Node, ) @classmethod - def get_node(cls, id, context, info): + def get_node(cls, info, id): return Reporter(id=2, first_name='Cookie Monster') - def resolve_articles(self, **args): + def resolve_articles(self, info, **args): return [Article(headline='Hi!')] class ArticleNode(DjangoObjectType): @@ -175,7 +175,7 @@ def test_should_node(): interfaces = (Node, ) @classmethod - def get_node(cls, id, context, info): + def get_node(cls, info, id): return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11)) class Query(graphene.ObjectType): @@ -183,7 +183,7 @@ def test_should_node(): reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) - def resolve_reporter(self): + def resolve_reporter(self, info): return Reporter(id=1, first_name='ABA', last_name='X') query = ''' @@ -250,7 +250,7 @@ def test_should_query_connectionfields(): class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) - def resolve_all_reporters(self, **args): + def resolve_all_reporters(self, info, **args): return [Reporter(id=1)] schema = graphene.Schema(query=Query) @@ -308,10 +308,10 @@ def test_should_keep_annotations(): all_reporters = DjangoConnectionField(ReporterType) all_articles = DjangoConnectionField(ArticleType) - def resolve_all_reporters(self, **args): + def resolve_all_reporters(self, info, **args): return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c') - def resolve_all_articles(self, **args): + def resolve_all_articles(self, info, **args): return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg') schema = graphene.Schema(query=Query) @@ -618,7 +618,7 @@ def test_should_query_promise_connectionfields(): class Query(graphene.ObjectType): all_reporters = DjangoConnectionField(ReporterType) - def resolve_all_reporters(self, **args): + def resolve_all_reporters(self, info, **args): return Promise.resolve([Reporter(id=1)]) schema = graphene.Schema(query=Query) @@ -677,7 +677,7 @@ def test_should_query_dataloader_fields(): articles = DjangoConnectionField(ArticleType) - def resolve_articles(self, **args): + def resolve_articles(self, info, **args): return article_loader.load(self.id) class Query(graphene.ObjectType): diff --git a/graphene_django/tests/test_types.py b/graphene_django/tests/test_types.py index 0ae12c0..f0185d4 100644 --- a/graphene_django/tests/test_types.py +++ b/graphene_django/tests/test_types.py @@ -38,7 +38,7 @@ def test_django_interface(): @patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) def test_django_get_node(get): - article = Article.get_node(1, None, None) + article = Article.get_node(None, 1) get.assert_called_with(pk=1) assert article.id == 1 diff --git a/graphene_django/types.py b/graphene_django/types.py index 407f548..aeef7a6 100644 --- a/graphene_django/types.py +++ b/graphene_django/types.py @@ -90,11 +90,11 @@ class DjangoObjectType(ObjectType): if not skip_registry: registry.register(cls) - def resolve_id(self): + def resolve_id(self, info): return self.pk @classmethod - def is_type_of(cls, root, context, info): + def is_type_of(cls, root, info): if isinstance(root, SimpleLazyObject): root._setup() root = root._wrapped @@ -108,7 +108,7 @@ class DjangoObjectType(ObjectType): return model == cls._meta.model @classmethod - def get_node(cls, id, context, info): + def get_node(cls, info, id): try: return cls._meta.model.objects.get(pk=id) except cls._meta.model.DoesNotExist: