From 4350582c526765e0bba7abf352c4b510db1be495 Mon Sep 17 00:00:00 2001 From: "arianon@openmailbox.org" Date: Fri, 19 May 2017 19:12:28 -0400 Subject: [PATCH 1/2] Support Connections created from Promises --- graphene_django/fields.py | 3 + graphene_django/tests/test_query.py | 143 ++++++++++++++++++++++++++++ setup.py | 1 + 3 files changed, 147 insertions(+) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index c2a2a8f..2125d16 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -2,6 +2,8 @@ from functools import partial from django.db.models.query import QuerySet +from promise import Promise + from graphene.types import Field, List from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -84,6 +86,7 @@ class DjangoConnectionField(ConnectionField): args['last'] = min(last, max_limit) iterable = resolver(root, args, context, info) + iterable = Promise.resolve(iterable).get() if iterable is None: iterable = default_manager iterable = maybe_queryset(iterable) diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index c1deebb..1041f7e 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -545,3 +545,146 @@ def test_should_error_if_first_is_greater_than_max(): assert result.data == expected graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False + + +def test_should_query_promise_connectionfields(): + from promise import Promise + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + def resolve_all_reporters(self, *args, **kwargs): + return Promise.resolve([Reporter(id=1)]) + + schema = graphene.Schema(query=Query) + query = ''' + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=' + } + }] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_should_query_dataloader_fields(): + from promise import Promise + from promise.dataloader import DataLoader + + def article_batch_load_fn(keys): + queryset = Article.objects.filter(reporter_id__in=keys) + return Promise.resolve([ + [article for article in queryset if article.reporter_id == id] + for id in keys + ]) + + article_loader = DataLoader(article_batch_load_fn) + + class ArticleType(DjangoObjectType): + + class Meta: + model = Article + interfaces = (Node, ) + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + articles = DjangoConnectionField(ArticleType) + + def resolve_articles(self, *args, **kwargs): + return article_loader.load(self.id) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + Article.objects.create( + headline='Article Node 1', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='es' + ) + Article.objects.create( + headline='Article Node 2', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='en' + ) + + schema = graphene.Schema(query=Query) + query = ''' + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=', + 'articles': { + 'edges': [{ + 'node': { + 'headline': 'Article Node 1', + } + }, { + 'node': { + 'headline': 'Article Node 2' + } + }] + } + } + }] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/setup.py b/setup.py index 2d2e578..0be5297 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ setup( 'Django>=1.6.0', 'iso8601', 'singledispatch>=3.4.0.3', + 'promise>=2.0', ], setup_requires=[ 'pytest-runner', From bfcac1d48c61ff742eda964f504bf55924cb11c3 Mon Sep 17 00:00:00 2001 From: "arianon@openmailbox.org" Date: Fri, 19 May 2017 19:33:00 -0400 Subject: [PATCH 2/2] Use Promise.then instead of Promise.get on DjangoConnectionField --- graphene_django/fields.py | 58 ++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/graphene_django/fields.py b/graphene_django/fields.py index 2125d16..f82e4b2 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -62,31 +62,7 @@ class DjangoConnectionField(ConnectionField): return default_queryset & queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, args, context, info): - first = args.get('first') - last = args.get('last') - - if enforce_first_or_last: - assert first or last, ( - 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' - ).format(info.field_name) - - if max_limit: - if first: - assert first <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' - ).format(first, info.field_name, max_limit) - args['first'] = min(first, max_limit) - - if last: - assert last <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' - ).format(first, info.field_name, max_limit) - args['last'] = min(last, max_limit) - - iterable = resolver(root, args, context, info) - iterable = Promise.resolve(iterable).get() + def resolve_connection(cls, connection, default_manager, args, iterable): if iterable is None: iterable = default_manager iterable = maybe_queryset(iterable) @@ -111,6 +87,38 @@ class DjangoConnectionField(ConnectionField): connection.length = _len return connection + @classmethod + def connection_resolver(cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, args, context, info): + first = args.get('first') + last = args.get('last') + + if enforce_first_or_last: + assert first or last, ( + 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' + ).format(info.field_name) + + if max_limit: + if first: + assert first <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['first'] = min(first, max_limit) + + if last: + assert last <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['last'] = min(last, max_limit) + + iterable = resolver(root, args, context, info) + on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + + if Promise.is_thenable(iterable): + return Promise.resolve(iterable).then(on_resolve) + + return on_resolve(iterable) + def get_resolver(self, parent_resolver): return partial( self.connection_resolver,