Code adapted to new resolver API

This commit is contained in:
Syrus Akbary 2017-07-28 09:43:27 -07:00
parent 64118790ff
commit 3d58148f03
11 changed files with 51 additions and 56 deletions

View File

@ -39,7 +39,8 @@ class DjangoDebugContext(object):
class DjangoDebugMiddleware(object): class DjangoDebugMiddleware(object):
def resolve(self, next, root, args, context, info): def resolve(self, next, root, info, **args):
context = info.context
django_debug = getattr(context, 'django_debug', None) django_debug = getattr(context, 'django_debug', None)
if not django_debug: if not django_debug:
if context is None: if context is None:
@ -52,6 +53,6 @@ class DjangoDebugMiddleware(object):
)) ))
if info.schema.get_type('DjangoDebug') == info.return_type: if info.schema.get_type('DjangoDebug') == info.return_type:
return context.django_debug.get_debug_promise() return context.django_debug.get_debug_promise()
promise = next(root, args, context, info) promise = next(root, info, **args)
context.django_debug.add_promise(promise) context.django_debug.add_promise(promise)
return promise return promise

View File

@ -33,7 +33,7 @@ def test_should_query_field():
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, info, **args):
return Reporter.objects.first() return Reporter.objects.first()
query = ''' query = '''
@ -80,7 +80,7 @@ def test_should_query_list():
all_reporters = graphene.List(ReporterType) all_reporters = graphene.List(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = '''
@ -129,7 +129,7 @@ def test_should_query_connection():
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = '''
@ -185,7 +185,7 @@ def test_should_query_connectionfilter():
s = graphene.String(resolver=lambda *_: "S") s = graphene.String(resolver=lambda *_: "S")
debug = graphene.Field(DjangoDebug, name='__debug') debug = graphene.Field(DjangoDebug, name='__debug')
def resolve_all_reporters(self, *args, **kwargs): def resolve_all_reporters(self, info, **args):
return Reporter.objects.all() return Reporter.objects.all()
query = ''' query = '''

View File

@ -4,7 +4,6 @@ from django.db.models.query import QuerySet
from promise import Promise from promise import Promise
from graphene import final_resolver
from graphene.types import Field, List from graphene.types import Field, List
from graphene.relay import ConnectionField, PageInfo from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice from graphql_relay.connection.arrayconnection import connection_from_list_slice
@ -23,11 +22,11 @@ class DjangoListField(Field):
return self.type.of_type._meta.node._meta.model return self.type.of_type._meta.node._meta.model
@staticmethod @staticmethod
def list_resolver(resolver, root, args, context, info): def list_resolver(resolver, root, info, **args):
return maybe_queryset(resolver(root, args, context, info)) return maybe_queryset(resolver(root, info, **args))
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return final_resolver(partial(self.list_resolver, parent_resolver)) return partial(self.list_resolver, parent_resolver)
class DjangoConnectionField(ConnectionField): class DjangoConnectionField(ConnectionField):
@ -98,7 +97,7 @@ class DjangoConnectionField(ConnectionField):
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, args, context, info): enforce_first_or_last, root, info, **args):
first = args.get('first') first = args.get('first')
last = args.get('last') last = args.get('last')
@ -120,7 +119,7 @@ class DjangoConnectionField(ConnectionField):
).format(first, info.field_name, max_limit) ).format(first, info.field_name, max_limit)
args['last'] = min(last, max_limit) args['last'] = min(last, max_limit)
iterable = resolver(root, args, context, info) iterable = resolver(root, info, **args)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args) on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
if Promise.is_thenable(iterable): if Promise.is_thenable(iterable):
@ -129,11 +128,11 @@ class DjangoConnectionField(ConnectionField):
return on_resolve(iterable) return on_resolve(iterable)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return final_resolver(partial( return partial(
self.connection_resolver, self.connection_resolver,
parent_resolver, parent_resolver,
self.type, self.type,
self.get_manager(), self.get_manager(),
self.max_limit, self.max_limit,
self.enforce_first_or_last self.enforce_first_or_last
)) )

View File

@ -1,7 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from graphene import final_resolver
from graphene.types.argument import to_arguments from graphene.types.argument import to_arguments
from ..fields import DjangoConnectionField from ..fields import DjangoConnectionField
from .utils import get_filtering_args_from_filterset, get_filterset_class from .utils import get_filtering_args_from_filterset, get_filterset_class
@ -69,7 +68,7 @@ class DjangoFilterConnectionField(DjangoConnectionField):
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, filterset_class, filtering_args, enforce_first_or_last, filterset_class, filtering_args,
root, args, context, info): root, info, **args):
filter_kwargs = {k: v for k, v in args.items() if k in filtering_args} filter_kwargs = {k: v for k, v in args.items() if k in filtering_args}
qs = filterset_class( qs = filterset_class(
data=filter_kwargs, data=filter_kwargs,
@ -83,13 +82,12 @@ class DjangoFilterConnectionField(DjangoConnectionField):
max_limit, max_limit,
enforce_first_or_last, enforce_first_or_last,
root, root,
args, info,
context, **args
info
) )
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return final_resolver(partial( return partial(
self.connection_resolver, self.connection_resolver,
parent_resolver, parent_resolver,
self.type, self.type,
@ -98,4 +96,4 @@ class DjangoFilterConnectionField(DjangoConnectionField):
self.enforce_first_or_last, self.enforce_first_or_last,
self.filterset_class, self.filterset_class,
self.filtering_args self.filtering_args
)) )

View File

@ -399,7 +399,7 @@ def test_should_query_filter_node_limit():
filterset_class=ReporterFilter filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice') return Reporter.objects.order_by('a_choice')
Reporter.objects.create( Reporter.objects.create(
@ -499,7 +499,7 @@ def test_should_query_filter_node_double_limit_raises():
filterset_class=ReporterFilter filterset_class=ReporterFilter
) )
def resolve_all_reporters(self, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.order_by('a_choice')[:2] return Reporter.objects.order_by('a_choice')[:2]
Reporter.objects.create( Reporter.objects.create(

View File

@ -1,7 +1,6 @@
from collections import OrderedDict from collections import OrderedDict
import graphene import graphene
from graphene import annotate, Context, ResolveInfo
from graphene.types import Field, InputField from graphene.types import Field, InputField
from graphene.types.mutation import MutationOptions from graphene.types.mutation import MutationOptions
from graphene.relay.mutation import ClientIDMutation from graphene.relay.mutation import ClientIDMutation
@ -68,12 +67,11 @@ class SerializerMutation(ClientIDMutation):
super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options) super(SerializerMutation, cls).__init_subclass_with_meta__(_meta=_meta, input_fields=input_fields, **options)
@classmethod @classmethod
@annotate(context=Context, info=ResolveInfo) def mutate_and_get_payload(cls, root, info, **input):
def mutate_and_get_payload(cls, root, input, context, info): serializer = cls._meta.serializer_class(data=input)
serializer = cls._meta.serializer_class(data=dict(input))
if serializer.is_valid(): if serializer.is_valid():
return cls.perform_mutate(serializer, context, info) return cls.perform_mutate(serializer, info)
else: else:
errors = [ errors = [
ErrorType(field=key, messages=value) ErrorType(field=key, messages=value)
@ -83,6 +81,6 @@ class SerializerMutation(ClientIDMutation):
return cls(errors=errors) return cls(errors=errors)
@classmethod @classmethod
def perform_mutate(cls, serializer, context, info): def perform_mutate(cls, serializer, info):
obj = serializer.save() obj = serializer.save()
return cls(**obj) return cls(**obj)

View File

@ -11,7 +11,7 @@ class Character(DjangoObjectType):
model = Reporter model = Reporter
interfaces = (relay.Node, ) interfaces = (relay.Node, )
def get_node(self, id, context, info): def get_node(self, info, id):
pass pass
@ -22,17 +22,17 @@ class Human(DjangoObjectType):
model = Article model = Article
interfaces = (relay.Node, ) interfaces = (relay.Node, )
def resolve_raises(self): def resolve_raises(self, info):
raise Exception("This field should raise exception") raise Exception("This field should raise exception")
def get_node(self, id): def get_node(self, info, id):
pass pass
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
human = graphene.Field(Human) human = graphene.Field(Human)
def resolve_human(self): def resolve_human(self, info):
return Human() return Human()

View File

@ -1,5 +1,5 @@
import graphene import graphene
from graphene import ObjectType, Schema, annotate, Context from graphene import ObjectType, Schema
class QueryRoot(ObjectType): class QueryRoot(ObjectType):
@ -8,21 +8,20 @@ class QueryRoot(ObjectType):
request = graphene.String(required=True) request = graphene.String(required=True)
test = graphene.String(who=graphene.String()) test = graphene.String(who=graphene.String())
def resolve_thrower(self): def resolve_thrower(self, info):
raise Exception("Throws!") raise Exception("Throws!")
@annotate(request=Context) def resolve_request(self, info):
def resolve_request(self, request): return info.context.GET.get('q')
return request.GET.get('q')
def resolve_test(self, who=None): def resolve_test(self, info, who=None):
return 'Hello %s' % (who or 'World') return 'Hello %s' % (who or 'World')
class MutationRoot(ObjectType): class MutationRoot(ObjectType):
write_test = graphene.Field(QueryRoot) write_test = graphene.Field(QueryRoot)
def resolve_write_test(self): def resolve_write_test(self, info):
return QueryRoot() return QueryRoot()

View File

@ -46,7 +46,7 @@ def test_should_query_simplelazy_objects():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(self): def resolve_reporter(self, info):
return SimpleLazyObject(lambda: Reporter(id=1)) return SimpleLazyObject(lambda: Reporter(id=1))
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -75,7 +75,7 @@ def test_should_query_well():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
def resolve_reporter(self): def resolve_reporter(self, info):
return Reporter(first_name='ABA', last_name='X') return Reporter(first_name='ABA', last_name='X')
query = ''' query = '''
@ -119,7 +119,7 @@ def test_should_query_postgres_fields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
event = graphene.Field(EventType) event = graphene.Field(EventType)
def resolve_event(self): def resolve_event(self, info):
return Event( return Event(
ages=(0, 10), ages=(0, 10),
data={'angry_babies': True}, data={'angry_babies': True},
@ -162,10 +162,10 @@ def test_should_node():
interfaces = (Node, ) interfaces = (Node, )
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
return Reporter(id=2, first_name='Cookie Monster') return Reporter(id=2, first_name='Cookie Monster')
def resolve_articles(self, **args): def resolve_articles(self, info, **args):
return [Article(headline='Hi!')] return [Article(headline='Hi!')]
class ArticleNode(DjangoObjectType): class ArticleNode(DjangoObjectType):
@ -175,7 +175,7 @@ def test_should_node():
interfaces = (Node, ) interfaces = (Node, )
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11)) return Article(id=1, headline='Article node', pub_date=datetime.date(2002, 3, 11))
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
@ -183,7 +183,7 @@ def test_should_node():
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
def resolve_reporter(self): def resolve_reporter(self, info):
return Reporter(id=1, first_name='ABA', last_name='X') return Reporter(id=1, first_name='ABA', last_name='X')
query = ''' query = '''
@ -250,7 +250,7 @@ def test_should_query_connectionfields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, **args): def resolve_all_reporters(self, info, **args):
return [Reporter(id=1)] return [Reporter(id=1)]
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -308,10 +308,10 @@ def test_should_keep_annotations():
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
all_articles = DjangoConnectionField(ArticleType) all_articles = DjangoConnectionField(ArticleType)
def resolve_all_reporters(self, **args): def resolve_all_reporters(self, info, **args):
return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c') return Reporter.objects.annotate(articles_c=Count('articles')).order_by('articles_c')
def resolve_all_articles(self, **args): def resolve_all_articles(self, info, **args):
return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg') return Article.objects.annotate(import_avg=Avg('importance')).order_by('import_avg')
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -618,7 +618,7 @@ def test_should_query_promise_connectionfields():
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType) all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, **args): def resolve_all_reporters(self, info, **args):
return Promise.resolve([Reporter(id=1)]) return Promise.resolve([Reporter(id=1)])
schema = graphene.Schema(query=Query) schema = graphene.Schema(query=Query)
@ -677,7 +677,7 @@ def test_should_query_dataloader_fields():
articles = DjangoConnectionField(ArticleType) articles = DjangoConnectionField(ArticleType)
def resolve_articles(self, **args): def resolve_articles(self, info, **args):
return article_loader.load(self.id) return article_loader.load(self.id)
class Query(graphene.ObjectType): class Query(graphene.ObjectType):

View File

@ -38,7 +38,7 @@ def test_django_interface():
@patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1)) @patch('graphene_django.tests.models.Article.objects.get', return_value=Article(id=1))
def test_django_get_node(get): def test_django_get_node(get):
article = Article.get_node(1, None, None) article = Article.get_node(None, 1)
get.assert_called_with(pk=1) get.assert_called_with(pk=1)
assert article.id == 1 assert article.id == 1

View File

@ -90,11 +90,11 @@ class DjangoObjectType(ObjectType):
if not skip_registry: if not skip_registry:
registry.register(cls) registry.register(cls)
def resolve_id(self): def resolve_id(self, info):
return self.pk return self.pk
@classmethod @classmethod
def is_type_of(cls, root, context, info): def is_type_of(cls, root, info):
if isinstance(root, SimpleLazyObject): if isinstance(root, SimpleLazyObject):
root._setup() root._setup()
root = root._wrapped root = root._wrapped
@ -108,7 +108,7 @@ class DjangoObjectType(ObjectType):
return model == cls._meta.model return model == cls._meta.model
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, info, id):
try: try:
return cls._meta.model.objects.get(pk=id) return cls._meta.model.objects.get(pk=id)
except cls._meta.model.DoesNotExist: except cls._meta.model.DoesNotExist: