mirror of
https://github.com/graphql-python/graphene-django.git
synced 2024-11-22 17:47:12 +03:00
Merge pull request #180 from arianon/master
Make DjangoConnectionField compatible with Promise-based iterables.
This commit is contained in:
commit
7c52aa3c7f
|
@ -2,6 +2,8 @@ from functools import partial
|
||||||
|
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
|
|
||||||
|
from promise import Promise
|
||||||
|
|
||||||
from graphene.types import Field, List
|
from graphene.types import Field, List
|
||||||
from graphene.relay import ConnectionField, PageInfo
|
from graphene.relay import ConnectionField, PageInfo
|
||||||
from graphql_relay.connection.arrayconnection import connection_from_list_slice
|
from graphql_relay.connection.arrayconnection import connection_from_list_slice
|
||||||
|
@ -60,30 +62,7 @@ class DjangoConnectionField(ConnectionField):
|
||||||
return default_queryset & queryset
|
return default_queryset & queryset
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
|
def resolve_connection(cls, connection, default_manager, args, iterable):
|
||||||
enforce_first_or_last, root, args, context, info):
|
|
||||||
first = args.get('first')
|
|
||||||
last = args.get('last')
|
|
||||||
|
|
||||||
if enforce_first_or_last:
|
|
||||||
assert first or last, (
|
|
||||||
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
|
|
||||||
).format(info.field_name)
|
|
||||||
|
|
||||||
if max_limit:
|
|
||||||
if first:
|
|
||||||
assert first <= max_limit, (
|
|
||||||
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
|
|
||||||
).format(first, info.field_name, max_limit)
|
|
||||||
args['first'] = min(first, max_limit)
|
|
||||||
|
|
||||||
if last:
|
|
||||||
assert last <= max_limit, (
|
|
||||||
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
|
|
||||||
).format(first, info.field_name, max_limit)
|
|
||||||
args['last'] = min(last, max_limit)
|
|
||||||
|
|
||||||
iterable = resolver(root, args, context, info)
|
|
||||||
if iterable is None:
|
if iterable is None:
|
||||||
iterable = default_manager
|
iterable = default_manager
|
||||||
iterable = maybe_queryset(iterable)
|
iterable = maybe_queryset(iterable)
|
||||||
|
@ -108,6 +87,38 @@ class DjangoConnectionField(ConnectionField):
|
||||||
connection.length = _len
|
connection.length = _len
|
||||||
return connection
|
return connection
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def connection_resolver(cls, resolver, connection, default_manager, max_limit,
|
||||||
|
enforce_first_or_last, root, args, context, info):
|
||||||
|
first = args.get('first')
|
||||||
|
last = args.get('last')
|
||||||
|
|
||||||
|
if enforce_first_or_last:
|
||||||
|
assert first or last, (
|
||||||
|
'You must provide a `first` or `last` value to properly paginate the `{}` connection.'
|
||||||
|
).format(info.field_name)
|
||||||
|
|
||||||
|
if max_limit:
|
||||||
|
if first:
|
||||||
|
assert first <= max_limit, (
|
||||||
|
'Requesting {} records on the `{}` connection exceeds the `first` limit of {} records.'
|
||||||
|
).format(first, info.field_name, max_limit)
|
||||||
|
args['first'] = min(first, max_limit)
|
||||||
|
|
||||||
|
if last:
|
||||||
|
assert last <= max_limit, (
|
||||||
|
'Requesting {} records on the `{}` connection exceeds the `last` limit of {} records.'
|
||||||
|
).format(first, info.field_name, max_limit)
|
||||||
|
args['last'] = min(last, max_limit)
|
||||||
|
|
||||||
|
iterable = resolver(root, args, context, info)
|
||||||
|
on_resolve = partial(cls.resolve_connection, connection, default_manager, args)
|
||||||
|
|
||||||
|
if Promise.is_thenable(iterable):
|
||||||
|
return Promise.resolve(iterable).then(on_resolve)
|
||||||
|
|
||||||
|
return on_resolve(iterable)
|
||||||
|
|
||||||
def get_resolver(self, parent_resolver):
|
def get_resolver(self, parent_resolver):
|
||||||
return partial(
|
return partial(
|
||||||
self.connection_resolver,
|
self.connection_resolver,
|
||||||
|
|
|
@ -545,3 +545,146 @@ def test_should_error_if_first_is_greater_than_max():
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
|
||||||
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False
|
graphene_settings.RELAY_CONNECTION_ENFORCE_FIRST_OR_LAST = False
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_query_promise_connectionfields():
|
||||||
|
from promise import Promise
|
||||||
|
|
||||||
|
class ReporterType(DjangoObjectType):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Reporter
|
||||||
|
interfaces = (Node, )
|
||||||
|
|
||||||
|
class Query(graphene.ObjectType):
|
||||||
|
all_reporters = DjangoConnectionField(ReporterType)
|
||||||
|
|
||||||
|
def resolve_all_reporters(self, *args, **kwargs):
|
||||||
|
return Promise.resolve([Reporter(id=1)])
|
||||||
|
|
||||||
|
schema = graphene.Schema(query=Query)
|
||||||
|
query = '''
|
||||||
|
query ReporterPromiseConnectionQuery {
|
||||||
|
allReporters(first: 1) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
'allReporters': {
|
||||||
|
'edges': [{
|
||||||
|
'node': {
|
||||||
|
'id': 'UmVwb3J0ZXJUeXBlOjE='
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_should_query_dataloader_fields():
|
||||||
|
from promise import Promise
|
||||||
|
from promise.dataloader import DataLoader
|
||||||
|
|
||||||
|
def article_batch_load_fn(keys):
|
||||||
|
queryset = Article.objects.filter(reporter_id__in=keys)
|
||||||
|
return Promise.resolve([
|
||||||
|
[article for article in queryset if article.reporter_id == id]
|
||||||
|
for id in keys
|
||||||
|
])
|
||||||
|
|
||||||
|
article_loader = DataLoader(article_batch_load_fn)
|
||||||
|
|
||||||
|
class ArticleType(DjangoObjectType):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Article
|
||||||
|
interfaces = (Node, )
|
||||||
|
|
||||||
|
class ReporterType(DjangoObjectType):
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = Reporter
|
||||||
|
interfaces = (Node, )
|
||||||
|
|
||||||
|
articles = DjangoConnectionField(ArticleType)
|
||||||
|
|
||||||
|
def resolve_articles(self, *args, **kwargs):
|
||||||
|
return article_loader.load(self.id)
|
||||||
|
|
||||||
|
class Query(graphene.ObjectType):
|
||||||
|
all_reporters = DjangoConnectionField(ReporterType)
|
||||||
|
|
||||||
|
r = Reporter.objects.create(
|
||||||
|
first_name='John',
|
||||||
|
last_name='Doe',
|
||||||
|
email='johndoe@example.com',
|
||||||
|
a_choice=1
|
||||||
|
)
|
||||||
|
Article.objects.create(
|
||||||
|
headline='Article Node 1',
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
reporter=r,
|
||||||
|
editor=r,
|
||||||
|
lang='es'
|
||||||
|
)
|
||||||
|
Article.objects.create(
|
||||||
|
headline='Article Node 2',
|
||||||
|
pub_date=datetime.date.today(),
|
||||||
|
reporter=r,
|
||||||
|
editor=r,
|
||||||
|
lang='en'
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = graphene.Schema(query=Query)
|
||||||
|
query = '''
|
||||||
|
query ReporterPromiseConnectionQuery {
|
||||||
|
allReporters(first: 1) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
id
|
||||||
|
articles(first: 2) {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
headline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
'allReporters': {
|
||||||
|
'edges': [{
|
||||||
|
'node': {
|
||||||
|
'id': 'UmVwb3J0ZXJUeXBlOjE=',
|
||||||
|
'articles': {
|
||||||
|
'edges': [{
|
||||||
|
'node': {
|
||||||
|
'headline': 'Article Node 1',
|
||||||
|
}
|
||||||
|
}, {
|
||||||
|
'node': {
|
||||||
|
'headline': 'Article Node 2'
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == expected
|
||||||
|
|
Loading…
Reference in New Issue
Block a user