mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-11-04 09:57:41 +03:00 
			
		
		
		
	Merge pull request #237 from Globegitter/connection-fields-object-types
Allow ConnectionFields to have ObjectTypes as per relay spec
This commit is contained in:
		
						commit
						bd1fbb6a33
					
				| 
						 | 
				
			
			@ -4,7 +4,8 @@ import six
 | 
			
		|||
from django.db import models
 | 
			
		||||
 | 
			
		||||
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
 | 
			
		||||
from ...relay.types import Connection, Node, NodeMeta
 | 
			
		||||
from ...relay.types import Node, NodeMeta
 | 
			
		||||
from ...relay.connection import Connection
 | 
			
		||||
from .converter import convert_django_field_with_choices
 | 
			
		||||
from .options import DjangoOptions
 | 
			
		||||
from .utils import get_reverse_fields
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,7 +5,8 @@ from sqlalchemy.inspection import inspect as sqlalchemyinspect
 | 
			
		|||
from sqlalchemy.orm.exc import NoResultFound
 | 
			
		||||
 | 
			
		||||
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
 | 
			
		||||
from ...relay.types import Connection, Node, NodeMeta
 | 
			
		||||
from ...relay.types import Node, NodeMeta
 | 
			
		||||
from ...relay.connection import Connection
 | 
			
		||||
from .converter import (convert_sqlalchemy_column,
 | 
			
		||||
                        convert_sqlalchemy_relationship)
 | 
			
		||||
from .options import SQLAlchemyOptions
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -6,12 +6,15 @@ from .fields import (
 | 
			
		|||
 | 
			
		||||
from .types import (
 | 
			
		||||
    Node,
 | 
			
		||||
    PageInfo,
 | 
			
		||||
    Edge,
 | 
			
		||||
    Connection,
 | 
			
		||||
    ClientIDMutation
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .connection import (
 | 
			
		||||
    PageInfo,
 | 
			
		||||
    Connection,
 | 
			
		||||
    Edge,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .utils import is_node
 | 
			
		||||
 | 
			
		||||
__all__ = ['ConnectionField', 'NodeField', 'GlobalIDField', 'Node',
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										87
									
								
								graphene/relay/connection.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								graphene/relay/connection.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1,87 @@
 | 
			
		|||
from collections import Iterable
 | 
			
		||||
 | 
			
		||||
from graphql_relay.connection.arrayconnection import connection_from_list
 | 
			
		||||
 | 
			
		||||
from ..core.classtypes import ObjectType
 | 
			
		||||
from ..core.types import Field, Boolean, String, List
 | 
			
		||||
from ..utils import memoize
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PageInfo(ObjectType):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, start_cursor="", end_cursor="",
 | 
			
		||||
                 has_previous_page=False, has_next_page=False, **kwargs):
 | 
			
		||||
        super(PageInfo, self).__init__(**kwargs)
 | 
			
		||||
        self.startCursor = start_cursor
 | 
			
		||||
        self.endCursor = end_cursor
 | 
			
		||||
        self.hasPreviousPage = has_previous_page
 | 
			
		||||
        self.hasNextPage = has_next_page
 | 
			
		||||
 | 
			
		||||
    hasNextPage = Boolean(
 | 
			
		||||
        required=True,
 | 
			
		||||
        description='When paginating forwards, are there more items?')
 | 
			
		||||
    hasPreviousPage = Boolean(
 | 
			
		||||
        required=True,
 | 
			
		||||
        description='When paginating backwards, are there more items?')
 | 
			
		||||
    startCursor = String(
 | 
			
		||||
        description='When paginating backwards, the cursor to continue.')
 | 
			
		||||
    endCursor = String(
 | 
			
		||||
        description='When paginating forwards, the cursor to continue.')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Edge(ObjectType):
 | 
			
		||||
    '''An edge in a connection.'''
 | 
			
		||||
    cursor = String(
 | 
			
		||||
        required=True, description='A cursor for use in pagination')
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @memoize
 | 
			
		||||
    def for_node(cls, node):
 | 
			
		||||
        node_field = Field(node, description='The item at the end of the edge')
 | 
			
		||||
        return type(
 | 
			
		||||
            '%s%s' % (node._meta.type_name, cls._meta.type_name),
 | 
			
		||||
            (cls,),
 | 
			
		||||
            {'node_type': node, 'node': node_field})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Connection(ObjectType):
 | 
			
		||||
    '''A connection to a list of items.'''
 | 
			
		||||
 | 
			
		||||
    def __init__(self, edges, page_info, **kwargs):
 | 
			
		||||
        super(Connection, self).__init__(**kwargs)
 | 
			
		||||
        self.edges = edges
 | 
			
		||||
        self.pageInfo = page_info
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        type_name = 'DefaultConnection'
 | 
			
		||||
 | 
			
		||||
    pageInfo = Field(PageInfo, required=True,
 | 
			
		||||
                     description='The Information to aid in pagination')
 | 
			
		||||
 | 
			
		||||
    _connection_data = None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @memoize
 | 
			
		||||
    def for_node(cls, node, edge_type=None):
 | 
			
		||||
        edge_type = edge_type or Edge.for_node(node)
 | 
			
		||||
        edges = List(edge_type, description='Information to aid in pagination.')
 | 
			
		||||
        return type(
 | 
			
		||||
            '%s%s' % (node._meta.type_name, cls._meta.type_name),
 | 
			
		||||
            (cls,),
 | 
			
		||||
            {'edge_type': edge_type, 'edges': edges})
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_list(cls, iterable, args, context, info):
 | 
			
		||||
        assert isinstance(
 | 
			
		||||
            iterable, Iterable), 'Resolved value from the connection field have to be iterable'
 | 
			
		||||
        connection = connection_from_list(
 | 
			
		||||
            iterable, args, connection_type=cls,
 | 
			
		||||
            edge_type=cls.edge_type, pageinfo_type=PageInfo)
 | 
			
		||||
        connection.set_connection_data(iterable)
 | 
			
		||||
        return connection
 | 
			
		||||
 | 
			
		||||
    def set_connection_data(self, data):
 | 
			
		||||
        self._connection_data = data
 | 
			
		||||
 | 
			
		||||
    def get_connection_data(self):
 | 
			
		||||
        return self._connection_data
 | 
			
		||||
| 
						 | 
				
			
			@ -6,6 +6,7 @@ from ..core.fields import Field
 | 
			
		|||
from ..core.types.definitions import NonNull
 | 
			
		||||
from ..core.types.scalars import ID, Int, String
 | 
			
		||||
from ..utils.wrap_resolver_function import has_context, with_context
 | 
			
		||||
from .connection import Connection, Edge
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ConnectionField(Field):
 | 
			
		||||
| 
						 | 
				
			
			@ -23,8 +24,8 @@ class ConnectionField(Field):
 | 
			
		|||
            last=Int(),
 | 
			
		||||
            description=description,
 | 
			
		||||
            **kwargs)
 | 
			
		||||
        self.connection_type = connection_type
 | 
			
		||||
        self.edge_type = edge_type
 | 
			
		||||
        self.connection_type = connection_type or Connection
 | 
			
		||||
        self.edge_type = edge_type or Edge
 | 
			
		||||
 | 
			
		||||
    @with_context
 | 
			
		||||
    def resolver(self, instance, args, context, info):
 | 
			
		||||
| 
						 | 
				
			
			@ -37,7 +38,7 @@ class ConnectionField(Field):
 | 
			
		|||
        else:
 | 
			
		||||
            resolved = super(ConnectionField, self).resolver(instance, args, info)
 | 
			
		||||
 | 
			
		||||
        if isinstance(resolved, connection_type):
 | 
			
		||||
        if isinstance(resolved, self.connection_type):
 | 
			
		||||
            return resolved
 | 
			
		||||
        return self.from_list(connection_type, resolved, args, context, info)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -45,19 +46,14 @@ class ConnectionField(Field):
 | 
			
		|||
        return connection_type.from_list(resolved, args, context, info)
 | 
			
		||||
 | 
			
		||||
    def get_connection_type(self, node):
 | 
			
		||||
        connection_type = self.connection_type or node.get_connection_type()
 | 
			
		||||
        edge_type = self.get_edge_type(node)
 | 
			
		||||
        return connection_type.for_node(node, edge_type=edge_type)
 | 
			
		||||
        return self.connection_type.for_node(node)
 | 
			
		||||
 | 
			
		||||
    def get_edge_type(self, node):
 | 
			
		||||
        edge_type = self.edge_type or node.get_edge_type()
 | 
			
		||||
        return edge_type.for_node(node)
 | 
			
		||||
        return self.edge_type.for_node(node)
 | 
			
		||||
 | 
			
		||||
    def get_type(self, schema):
 | 
			
		||||
        from graphene.relay.utils import is_node
 | 
			
		||||
        type = schema.T(self.type)
 | 
			
		||||
        node = schema.objecttype(type)
 | 
			
		||||
        assert is_node(node), 'Only nodes have connections.'
 | 
			
		||||
        schema.register(node)
 | 
			
		||||
        connection_type = self.get_connection_type(node)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -19,6 +19,9 @@ class MyNode(relay.Node):
 | 
			
		|||
    def get_node(cls, id, info):
 | 
			
		||||
        return MyNode(id=id, name='mo')
 | 
			
		||||
 | 
			
		||||
class MyObject(graphene.ObjectType):
 | 
			
		||||
    name = graphene.String()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SpecialNode(relay.Node):
 | 
			
		||||
    value = graphene.String()
 | 
			
		||||
| 
						 | 
				
			
			@ -29,6 +32,9 @@ class SpecialNode(relay.Node):
 | 
			
		|||
        value = "!!!" if context.get('is_special') else "???"
 | 
			
		||||
        return SpecialNode(id=id, value=value)
 | 
			
		||||
 | 
			
		||||
def _create_my_node_edge(myNode):
 | 
			
		||||
    edge_type = relay.Edge.for_node(MyNode)
 | 
			
		||||
    return edge_type(node=myNode, cursor=str(myNode.id))
 | 
			
		||||
 | 
			
		||||
class Query(graphene.ObjectType):
 | 
			
		||||
    my_node = relay.NodeField(MyNode)
 | 
			
		||||
| 
						 | 
				
			
			@ -40,6 +46,12 @@ class Query(graphene.ObjectType):
 | 
			
		|||
    context_nodes = relay.ConnectionField(
 | 
			
		||||
        MyNode, connection_type=MyConnection, customArg=graphene.String())
 | 
			
		||||
 | 
			
		||||
    connection_type_nodes = relay.ConnectionField(
 | 
			
		||||
        MyNode, connection_type=MyConnection)
 | 
			
		||||
 | 
			
		||||
    all_my_objects = relay.ConnectionField(
 | 
			
		||||
        MyObject, connection_type=MyConnection)
 | 
			
		||||
 | 
			
		||||
    def resolve_all_my_nodes(self, args, info):
 | 
			
		||||
        custom_arg = args.get('customArg')
 | 
			
		||||
        assert custom_arg == "1"
 | 
			
		||||
| 
						 | 
				
			
			@ -51,6 +63,16 @@ class Query(graphene.ObjectType):
 | 
			
		|||
        assert custom_arg == "1"
 | 
			
		||||
        return [MyNode(name='my')]
 | 
			
		||||
 | 
			
		||||
    def resolve_connection_type_nodes(self, args, info):
 | 
			
		||||
        edges = [_create_my_node_edge(n) for n in [MyNode(id='1', name='my')]]
 | 
			
		||||
        connection_type = MyConnection.for_node(MyNode)
 | 
			
		||||
 | 
			
		||||
        return connection_type(
 | 
			
		||||
            edges=edges, page_info=relay.PageInfo(has_next_page=True))
 | 
			
		||||
 | 
			
		||||
    def resolve_all_my_objects(self, args, info):
 | 
			
		||||
        return [MyObject(name='my_object')]
 | 
			
		||||
 | 
			
		||||
schema.query = Query
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -135,6 +157,74 @@ def test_connectionfield_context_query():
 | 
			
		|||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_connectionfield_resolve_returns_connection_type_directly():
 | 
			
		||||
    query = '''
 | 
			
		||||
    query RebelsShipsQuery {
 | 
			
		||||
      connectionTypeNodes {
 | 
			
		||||
        edges {
 | 
			
		||||
          node {
 | 
			
		||||
            name
 | 
			
		||||
          }
 | 
			
		||||
        },
 | 
			
		||||
        myCustomField
 | 
			
		||||
        pageInfo {
 | 
			
		||||
          hasNextPage
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    '''
 | 
			
		||||
    expected = {
 | 
			
		||||
        'connectionTypeNodes': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'name': 'my'
 | 
			
		||||
                }
 | 
			
		||||
            }],
 | 
			
		||||
            'myCustomField': 'Custom',
 | 
			
		||||
            'pageInfo': {
 | 
			
		||||
                'hasNextPage': True,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_connectionfield_resolve_returning_objects():
 | 
			
		||||
    query = '''
 | 
			
		||||
    query RebelsShipsQuery {
 | 
			
		||||
      allMyObjects {
 | 
			
		||||
        edges {
 | 
			
		||||
          node {
 | 
			
		||||
            name
 | 
			
		||||
          }
 | 
			
		||||
        },
 | 
			
		||||
        myCustomField
 | 
			
		||||
        pageInfo {
 | 
			
		||||
          hasNextPage
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    '''
 | 
			
		||||
    expected = {
 | 
			
		||||
        'allMyObjects': {
 | 
			
		||||
            'edges': [{
 | 
			
		||||
                'node': {
 | 
			
		||||
                    'name': 'my_object'
 | 
			
		||||
                }
 | 
			
		||||
            }],
 | 
			
		||||
            'myCustomField': 'Custom',
 | 
			
		||||
            'pageInfo': {
 | 
			
		||||
                'hasNextPage': False,
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    result = schema.execute(query)
 | 
			
		||||
    assert not result.errors
 | 
			
		||||
    assert result.data == expected
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize('specialness,value', [(True, '!!!'), (False, '???')])
 | 
			
		||||
def test_get_node_info(specialness, value):
 | 
			
		||||
    query = '''
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,108 +1,21 @@
 | 
			
		|||
import inspect
 | 
			
		||||
import warnings
 | 
			
		||||
from collections import Iterable
 | 
			
		||||
from functools import wraps
 | 
			
		||||
 | 
			
		||||
import six
 | 
			
		||||
 | 
			
		||||
from graphql_relay.connection.arrayconnection import connection_from_list
 | 
			
		||||
from graphql_relay.node.node import to_global_id
 | 
			
		||||
 | 
			
		||||
from ..core.classtypes import InputObjectType, Interface, Mutation, ObjectType
 | 
			
		||||
from ..core.classtypes import InputObjectType, Interface, Mutation
 | 
			
		||||
from ..core.classtypes.interface import InterfaceMeta
 | 
			
		||||
from ..core.classtypes.mutation import MutationMeta
 | 
			
		||||
from ..core.types import Boolean, Field, List, String
 | 
			
		||||
from ..core.types import String
 | 
			
		||||
from ..core.types.argument import ArgumentsGroup
 | 
			
		||||
from ..core.types.definitions import NonNull
 | 
			
		||||
from ..utils import memoize
 | 
			
		||||
from ..utils.wrap_resolver_function import has_context, with_context
 | 
			
		||||
from .fields import GlobalIDField
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PageInfo(ObjectType):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, start_cursor="", end_cursor="",
 | 
			
		||||
                 has_previous_page=False, has_next_page=False, **kwargs):
 | 
			
		||||
        super(PageInfo, self).__init__(**kwargs)
 | 
			
		||||
        self.startCursor = start_cursor
 | 
			
		||||
        self.endCursor = end_cursor
 | 
			
		||||
        self.hasPreviousPage = has_previous_page
 | 
			
		||||
        self.hasNextPage = has_next_page
 | 
			
		||||
 | 
			
		||||
    hasNextPage = Boolean(
 | 
			
		||||
        required=True,
 | 
			
		||||
        description='When paginating forwards, are there more items?')
 | 
			
		||||
    hasPreviousPage = Boolean(
 | 
			
		||||
        required=True,
 | 
			
		||||
        description='When paginating backwards, are there more items?')
 | 
			
		||||
    startCursor = String(
 | 
			
		||||
        description='When paginating backwards, the cursor to continue.')
 | 
			
		||||
    endCursor = String(
 | 
			
		||||
        description='When paginating forwards, the cursor to continue.')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Edge(ObjectType):
 | 
			
		||||
    '''An edge in a connection.'''
 | 
			
		||||
    cursor = String(
 | 
			
		||||
        required=True, description='A cursor for use in pagination')
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @memoize
 | 
			
		||||
    def for_node(cls, node):
 | 
			
		||||
        from graphene.relay.utils import is_node
 | 
			
		||||
        assert is_node(node), 'ObjectTypes in a edge have to be Nodes'
 | 
			
		||||
        node_field = Field(node, description='The item at the end of the edge')
 | 
			
		||||
        return type(
 | 
			
		||||
            '%s%s' % (node._meta.type_name, cls._meta.type_name),
 | 
			
		||||
            (cls,),
 | 
			
		||||
            {'node_type': node, 'node': node_field})
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Connection(ObjectType):
 | 
			
		||||
    '''A connection to a list of items.'''
 | 
			
		||||
 | 
			
		||||
    def __init__(self, edges, page_info, **kwargs):
 | 
			
		||||
        super(Connection, self).__init__(**kwargs)
 | 
			
		||||
        self.edges = edges
 | 
			
		||||
        self.pageInfo = page_info
 | 
			
		||||
 | 
			
		||||
    class Meta:
 | 
			
		||||
        type_name = 'DefaultConnection'
 | 
			
		||||
 | 
			
		||||
    pageInfo = Field(PageInfo, required=True,
 | 
			
		||||
                     description='The Information to aid in pagination')
 | 
			
		||||
 | 
			
		||||
    _connection_data = None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    @memoize
 | 
			
		||||
    def for_node(cls, node, edge_type=None):
 | 
			
		||||
        from graphene.relay.utils import is_node
 | 
			
		||||
        edge_type = edge_type or Edge.for_node(node)
 | 
			
		||||
        assert is_node(node), 'ObjectTypes in a connection have to be Nodes'
 | 
			
		||||
        edges = List(edge_type, description='Information to aid in pagination.')
 | 
			
		||||
        return type(
 | 
			
		||||
            '%s%s' % (node._meta.type_name, cls._meta.type_name),
 | 
			
		||||
            (cls,),
 | 
			
		||||
            {'edge_type': edge_type, 'edges': edges})
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_list(cls, iterable, args, context, info):
 | 
			
		||||
        assert isinstance(
 | 
			
		||||
            iterable, Iterable), 'Resolved value from the connection field have to be iterable'
 | 
			
		||||
        connection = connection_from_list(
 | 
			
		||||
            iterable, args, connection_type=cls,
 | 
			
		||||
            edge_type=cls.edge_type, pageinfo_type=PageInfo)
 | 
			
		||||
        connection.set_connection_data(iterable)
 | 
			
		||||
        return connection
 | 
			
		||||
 | 
			
		||||
    def set_connection_data(self, data):
 | 
			
		||||
        self._connection_data = data
 | 
			
		||||
 | 
			
		||||
    def get_connection_data(self):
 | 
			
		||||
        return self._connection_data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NodeMeta(InterfaceMeta):
 | 
			
		||||
 | 
			
		||||
    def construct_get_node(cls):
 | 
			
		||||
| 
						 | 
				
			
			@ -153,17 +66,6 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
 | 
			
		|||
    def to_global_id(self):
 | 
			
		||||
        return self.global_id(self.id)
 | 
			
		||||
 | 
			
		||||
    connection_type = Connection
 | 
			
		||||
    edge_type = Edge
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_connection_type(cls):
 | 
			
		||||
        return cls.connection_type
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def get_edge_type(cls):
 | 
			
		||||
        return cls.edge_type
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MutationInputType(InputObjectType):
 | 
			
		||||
    clientMutationId = String(required=True)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user