diff --git a/graphene_django/fields.py b/graphene_django/fields.py index c2a2a8f..f82e4b2 100644 --- a/graphene_django/fields.py +++ b/graphene_django/fields.py @@ -2,6 +2,8 @@ from functools import partial from django.db.models.query import QuerySet +from promise import Promise + from graphene.types import Field, List from graphene.relay import ConnectionField, PageInfo from graphql_relay.connection.arrayconnection import connection_from_list_slice @@ -60,30 +62,7 @@ class DjangoConnectionField(ConnectionField): return default_queryset & queryset @classmethod - def connection_resolver(cls, resolver, connection, default_manager, max_limit, - enforce_first_or_last, root, args, context, info): - first = args.get('first') - last = args.get('last') - - if enforce_first_or_last: - assert first or last, ( - 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' - ).format(info.field_name) - - if max_limit: - if first: - assert first <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' - ).format(first, info.field_name, max_limit) - args['first'] = min(first, max_limit) - - if last: - assert last <= max_limit, ( - 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' - ).format(first, info.field_name, max_limit) - args['last'] = min(last, max_limit) - - iterable = resolver(root, args, context, info) + def resolve_connection(cls, connection, default_manager, args, iterable): if iterable is None: iterable = default_manager iterable = maybe_queryset(iterable) @@ -108,6 +87,38 @@ class DjangoConnectionField(ConnectionField): connection.length = _len return connection + @classmethod + def connection_resolver(cls, resolver, connection, default_manager, max_limit, + enforce_first_or_last, root, args, context, info): + first = args.get('first') + last = args.get('last') + + if enforce_first_or_last: + assert first or last, ( + 'You must provide a `first` or `last` value to properly paginate the `{}` connection.' + ).format(info.field_name) + + if max_limit: + if first: + assert first <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['first'] = min(first, max_limit) + + if last: + assert last <= max_limit, ( + 'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.' + ).format(first, info.field_name, max_limit) + args['last'] = min(last, max_limit) + + iterable = resolver(root, args, context, info) + on_resolve = partial(cls.resolve_connection, connection, default_manager, args) + + if Promise.is_thenable(iterable): + return Promise.resolve(iterable).then(on_resolve) + + return on_resolve(iterable) + def get_resolver(self, parent_resolver): return partial( self.connection_resolver, diff --git a/graphene_django/tests/test_query.py b/graphene_django/tests/test_query.py index c1deebb..1041f7e 100644 --- a/graphene_django/tests/test_query.py +++ b/graphene_django/tests/test_query.py @@ -545,3 +545,146 @@ def test_should_error_if_first_is_greater_than_max(): assert result.data == expected graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False + + +def test_should_query_promise_connectionfields(): + from promise import Promise + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + def resolve_all_reporters(self, *args, **kwargs): + return Promise.resolve([Reporter(id=1)]) + + schema = graphene.Schema(query=Query) + query = ''' + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=' + } + }] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected + + +def test_should_query_dataloader_fields(): + from promise import Promise + from promise.dataloader import DataLoader + + def article_batch_load_fn(keys): + queryset = Article.objects.filter(reporter_id__in=keys) + return Promise.resolve([ + [article for article in queryset if article.reporter_id == id] + for id in keys + ]) + + article_loader = DataLoader(article_batch_load_fn) + + class ArticleType(DjangoObjectType): + + class Meta: + model = Article + interfaces = (Node, ) + + class ReporterType(DjangoObjectType): + + class Meta: + model = Reporter + interfaces = (Node, ) + + articles = DjangoConnectionField(ArticleType) + + def resolve_articles(self, *args, **kwargs): + return article_loader.load(self.id) + + class Query(graphene.ObjectType): + all_reporters = DjangoConnectionField(ReporterType) + + r = Reporter.objects.create( + first_name='John', + last_name='Doe', + email='johndoe@example.com', + a_choice=1 + ) + Article.objects.create( + headline='Article Node 1', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='es' + ) + Article.objects.create( + headline='Article Node 2', + pub_date=datetime.date.today(), + reporter=r, + editor=r, + lang='en' + ) + + schema = graphene.Schema(query=Query) + query = ''' + query ReporterPromiseConnectionQuery { + allReporters(first: 1) { + edges { + node { + id + articles(first: 2) { + edges { + node { + headline + } + } + } + } + } + } + } + ''' + + expected = { + 'allReporters': { + 'edges': [{ + 'node': { + 'id': 'UmVwb3J0ZXJUeXBlOjE=', + 'articles': { + 'edges': [{ + 'node': { + 'headline': 'Article Node 1', + } + }, { + 'node': { + 'headline': 'Article Node 2' + } + }] + } + } + }] + } + } + + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/setup.py b/setup.py index 9673d9b..8a503c2 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ setup( 'Django>=1.8.0', 'iso8601', 'singledispatch>=3.4.0.3', + 'promise>=2.0', ], setup_requires=[ 'pytest-runner',