Merge pull request #180 from arianon/master

Make DjangoConnectionField compatible with Promise-based iterables.
This commit is contained in:
Syrus Akbary 2017-06-24 15:44:51 -07:00 committed by GitHub
commit 7c52aa3c7f
3 changed files with 179 additions and 24 deletions

View File

@ -2,6 +2,8 @@ from functools import partial
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
from promise import Promise
from graphene.types import Field, List from graphene.types import Field, List
from graphene.relay import ConnectionField, PageInfo from graphene.relay import ConnectionField, PageInfo
from graphql_relay.connection.arrayconnection import connection_from_list_slice from graphql_relay.connection.arrayconnection import connection_from_list_slice
@ -60,30 +62,7 @@ class DjangoConnectionField(ConnectionField):
return default_queryset & queryset return default_queryset & queryset
@classmethod @classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit, def resolve_connection(cls, connection, default_manager, args, iterable):
enforce_first_or_last, root, args, context, info):
first = args.get('first')
last = args.get('last')
if enforce_first_or_last:
assert first or last, (
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
).format(info.field_name)
if max_limit:
if first:
assert first <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
).format(first, info.field_name, max_limit)
args['first'] = min(first, max_limit)
if last:
assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
).format(first, info.field_name, max_limit)
args['last'] = min(last, max_limit)
iterable = resolver(root, args, context, info)
if iterable is None: if iterable is None:
iterable = default_manager iterable = default_manager
iterable = maybe_queryset(iterable) iterable = maybe_queryset(iterable)
@ -108,6 +87,38 @@ class DjangoConnectionField(ConnectionField):
connection.length = _len connection.length = _len
return connection return connection
@classmethod
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
enforce_first_or_last, root, args, context, info):
first = args.get('first')
last = args.get('last')
if enforce_first_or_last:
assert first or last, (
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
).format(info.field_name)
if max_limit:
if first:
assert first <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
).format(first, info.field_name, max_limit)
args['first'] = min(first, max_limit)
if last:
assert last <= max_limit, (
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
).format(first, info.field_name, max_limit)
args['last'] = min(last, max_limit)
iterable = resolver(root, args, context, info)
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
if Promise.is_thenable(iterable):
return Promise.resolve(iterable).then(on_resolve)
return on_resolve(iterable)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return partial( return partial(
self.connection_resolver, self.connection_resolver,

View File

@ -545,3 +545,146 @@ def test_should_error_if_first_is_greater_than_max():
assert result.data == expected assert result.data == expected
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False
def test_should_query_promise_connectionfields():
from promise import Promise
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
def resolve_all_reporters(self, *args, **kwargs):
return Promise.resolve([Reporter(id=1)])
schema = graphene.Schema(query=Query)
query = '''
query ReporterPromiseConnectionQuery {
allReporters(first: 1) {
edges {
node {
id
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE='
}
}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_query_dataloader_fields():
from promise import Promise
from promise.dataloader import DataLoader
def article_batch_load_fn(keys):
queryset = Article.objects.filter(reporter_id__in=keys)
return Promise.resolve([
[article for article in queryset if article.reporter_id == id]
for id in keys
])
article_loader = DataLoader(article_batch_load_fn)
class ArticleType(DjangoObjectType):
class Meta:
model = Article
interfaces = (Node, )
class ReporterType(DjangoObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
articles = DjangoConnectionField(ArticleType)
def resolve_articles(self, *args, **kwargs):
return article_loader.load(self.id)
class Query(graphene.ObjectType):
all_reporters = DjangoConnectionField(ReporterType)
r = Reporter.objects.create(
first_name='John',
last_name='Doe',
email='johndoe@example.com',
a_choice=1
)
Article.objects.create(
headline='Article Node 1',
pub_date=datetime.date.today(),
reporter=r,
editor=r,
lang='es'
)
Article.objects.create(
headline='Article Node 2',
pub_date=datetime.date.today(),
reporter=r,
editor=r,
lang='en'
)
schema = graphene.Schema(query=Query)
query = '''
query ReporterPromiseConnectionQuery {
allReporters(first: 1) {
edges {
node {
id
articles(first: 2) {
edges {
node {
headline
}
}
}
}
}
}
}
'''
expected = {
'allReporters': {
'edges': [{
'node': {
'id': 'UmVwb3J0ZXJUeXBlOjE=',
'articles': {
'edges': [{
'node': {
'headline': 'Article Node 1',
}
}, {
'node': {
'headline': 'Article Node 2'
}
}]
}
}
}]
}
}
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -47,6 +47,7 @@ setup(
'Django>=1.8.0', 'Django>=1.8.0',
'iso8601', 'iso8601',
'singledispatch>=3.4.0.3', 'singledispatch>=3.4.0.3',
'promise>=2.0',
], ],
setup_requires=[ setup_requires=[
'pytest-runner', 'pytest-runner',