mirror of
https://github.com/graphql-python/graphene.git
synced 2025-07-27 00:09:45 +03:00
Merge 1e1dc4c83a
into 17ba01570a
This commit is contained in:
commit
40256e7aeb
|
@ -1,4 +1,5 @@
|
||||||
import six
|
import six
|
||||||
|
import functools
|
||||||
|
|
||||||
from graphql_relay.node.node import from_global_id
|
from graphql_relay.node.node import from_global_id
|
||||||
|
|
||||||
|
@ -9,6 +10,10 @@ from ..utils.wrap_resolver_function import has_context, with_context
|
||||||
from .connection import Connection, Edge
|
from .connection import Connection, Edge
|
||||||
|
|
||||||
|
|
||||||
|
def _is_thenable(obj):
|
||||||
|
return callable(getattr(obj, "then", None))
|
||||||
|
|
||||||
|
|
||||||
class ConnectionField(Field):
|
class ConnectionField(Field):
|
||||||
|
|
||||||
def __init__(self, type, resolver=None, description='',
|
def __init__(self, type, resolver=None, description='',
|
||||||
|
@ -27,6 +32,11 @@ class ConnectionField(Field):
|
||||||
self.connection_type = connection_type or Connection
|
self.connection_type = connection_type or Connection
|
||||||
self.edge_type = edge_type or Edge
|
self.edge_type = edge_type or Edge
|
||||||
|
|
||||||
|
def _get_connection_type(self, connection_type, args, context, info, resolved):
|
||||||
|
if isinstance(resolved, self.connection_type):
|
||||||
|
return resolved
|
||||||
|
return self.from_list(connection_type, resolved, args, context, info)
|
||||||
|
|
||||||
@with_context
|
@with_context
|
||||||
def resolver(self, instance, args, context, info):
|
def resolver(self, instance, args, context, info):
|
||||||
schema = info.schema.graphene_schema
|
schema = info.schema.graphene_schema
|
||||||
|
@ -38,9 +48,12 @@ class ConnectionField(Field):
|
||||||
else:
|
else:
|
||||||
resolved = super(ConnectionField, self).resolver(instance, args, info)
|
resolved = super(ConnectionField, self).resolver(instance, args, info)
|
||||||
|
|
||||||
if isinstance(resolved, self.connection_type):
|
get_connection_type = functools.partial(self._get_connection_type, connection_type, args, context, info)
|
||||||
return resolved
|
|
||||||
return self.from_list(connection_type, resolved, args, context, info)
|
if _is_thenable(resolved):
|
||||||
|
return resolved.then(get_connection_type)
|
||||||
|
|
||||||
|
return get_connection_type(resolved)
|
||||||
|
|
||||||
def from_list(self, connection_type, resolved, args, context, info):
|
def from_list(self, connection_type, resolved, args, context, info):
|
||||||
return connection_type.from_list(resolved, args, context, info)
|
return connection_type.from_list(resolved, args, context, info)
|
||||||
|
|
|
@ -4,6 +4,8 @@ from graphql.type import GraphQLID, GraphQLNonNull
|
||||||
import graphene
|
import graphene
|
||||||
from graphene import relay, with_context
|
from graphene import relay, with_context
|
||||||
|
|
||||||
|
from promise import Promise
|
||||||
|
|
||||||
schema = graphene.Schema()
|
schema = graphene.Schema()
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,6 +54,9 @@ class Query(graphene.ObjectType):
|
||||||
connection_type_nodes = relay.ConnectionField(
|
connection_type_nodes = relay.ConnectionField(
|
||||||
MyNode, connection_type=MyConnection)
|
MyNode, connection_type=MyConnection)
|
||||||
|
|
||||||
|
promise_connection_type = relay.ConnectionField(
|
||||||
|
MyNode, connection_type=MyConnection)
|
||||||
|
|
||||||
all_my_objects = relay.ConnectionField(
|
all_my_objects = relay.ConnectionField(
|
||||||
MyObject, connection_type=MyConnection)
|
MyObject, connection_type=MyConnection)
|
||||||
|
|
||||||
|
@ -76,6 +81,9 @@ class Query(graphene.ObjectType):
|
||||||
def resolve_all_my_objects(self, args, info):
|
def resolve_all_my_objects(self, args, info):
|
||||||
return [MyObject(name='my_object')]
|
return [MyObject(name='my_object')]
|
||||||
|
|
||||||
|
def resolve_promise_connection_type(self, args, info):
|
||||||
|
return Promise.resolve('async name').then(lambda name: [MyNode(id='1', name=name)])
|
||||||
|
|
||||||
schema.query = Query
|
schema.query = Query
|
||||||
|
|
||||||
|
|
||||||
|
@ -228,6 +236,32 @@ def test_connectionfield_resolve_returning_objects():
|
||||||
assert result.data == expected
|
assert result.data == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_connectionfield_resolve_returning_promise():
|
||||||
|
query = '''
|
||||||
|
query RebelsShipsQuery {
|
||||||
|
promiseConnectionType {
|
||||||
|
edges {
|
||||||
|
node {
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
expected = {
|
||||||
|
'promiseConnectionType': {
|
||||||
|
'edges': [{
|
||||||
|
'node': {
|
||||||
|
'name': 'async name'
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result = schema.execute(query)
|
||||||
|
assert not result.errors
|
||||||
|
assert result.data == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')])
|
@pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')])
|
||||||
def test_get_node_info(specialness, value):
|
def test_get_node_info(specialness, value):
|
||||||
query = '''
|
query = '''
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -66,6 +66,7 @@ setup(
|
||||||
'sqlalchemy',
|
'sqlalchemy',
|
||||||
'sqlalchemy_utils',
|
'sqlalchemy_utils',
|
||||||
'mock',
|
'mock',
|
||||||
|
'promse',
|
||||||
# Required for Django postgres fields testing
|
# Required for Django postgres fields testing
|
||||||
'psycopg2',
|
'psycopg2',
|
||||||
],
|
],
|
||||||
|
|
Loading…
Reference in New Issue
Block a user