diff --git a/graphene/contrib/django/fields.py b/graphene/contrib/django/fields.py index 8cc7838e..5a4d95ae 100644 --- a/graphene/contrib/django/fields.py +++ b/graphene/contrib/django/fields.py @@ -11,7 +11,7 @@ class DjangoConnectionField(ConnectionField): def wrap_resolved(self, value, instance, args, info): schema = info.schema.graphene_schema - return lazy_map(value, self.type.get_object_type(schema)) + return lazy_map(value, self.type) class LazyListField(Field): @@ -22,7 +22,7 @@ class LazyListField(Field): def resolver(self, instance, args, info): schema = info.schema.graphene_schema resolved = super(LazyListField, self).resolver(instance, args, info) - return lazy_map(resolved, self.get_object_type(schema)) + return lazy_map(resolved, self.type) class ConnectionOrListField(Field): @@ -30,12 +30,14 @@ class ConnectionOrListField(Field): def internal_type(self, schema): model_field = self.type field_object_type = model_field.get_object_type(schema) + if not field_object_type: + raise SkipField() if is_node(field_object_type): - field = DjangoConnectionField(model_field) + field = DjangoConnectionField(field_object_type) else: - field = LazyListField(model_field) + field = LazyListField(field_object_type) field.contribute_to_class(self.object_type, self.name) - return field.internal_type(schema) + return schema.T(field) class DjangoModelField(FieldType): diff --git a/graphene/relay/fields.py b/graphene/relay/fields.py index 285d5a69..b3f6d328 100644 --- a/graphene/relay/fields.py +++ b/graphene/relay/fields.py @@ -10,12 +10,12 @@ from ..core.types.scalars import ID, Int, String class ConnectionField(Field): - def __init__(self, field_type, resolver=None, description='', + def __init__(self, type, resolver=None, description='', connection_type=None, edge_type=None, **kwargs): super( ConnectionField, self).__init__( - field_type, + type, resolver=resolver, before=String(), after=String(), @@ -38,7 +38,6 @@ class ConnectionField(Field): resolved = self.wrap_resolved(resolved, instance, args, info) assert isinstance( resolved, Iterable), 'Resolved value from the connection field have to be iterable' - type = schema.T(self.type) node = schema.objecttype(type) connection_type = self.get_connection_type(node) @@ -56,7 +55,8 @@ class ConnectionField(Field): return connection_type.for_node(node, edge_type=edge_type) def get_edge_type(self, node): - return self.edge_type or node.get_edge_type() + edge_type = self.edge_type or node.get_edge_type() + return edge_type.for_node(node) def get_type(self, schema): from graphene.relay.utils import is_node @@ -65,6 +65,7 @@ class ConnectionField(Field): assert is_node(node), 'Only nodes have connections.' schema.register(node) connection_type = self.get_connection_type(node) + return connection_type diff --git a/graphene/relay/tests/test_types.py b/graphene/relay/tests/test_types.py index da19eec4..1f3c56bf 100644 --- a/graphene/relay/tests/test_types.py +++ b/graphene/relay/tests/test_types.py @@ -1,4 +1,5 @@ from pytest import raises +from graphql.core.type import GraphQLList import graphene from graphene import relay @@ -31,3 +32,15 @@ def test_node_should_have_same_connection_always(): def test_node_should_have_id_field(): assert 'id' in OtherNode._meta.fields_map + + +def test_node_connection_should_have_edge(): + connection = relay.Connection.for_node(OtherNode) + edge = relay.Edge.for_node(OtherNode) + connection_type = schema.T(connection) + connection_fields = connection_type.get_fields() + assert 'edges' in connection_fields + assert 'pageInfo' in connection_fields + edges_type = connection_fields['edges'].type + assert isinstance(edges_type, GraphQLList) + assert edges_type.of_type == schema.T(edge) diff --git a/graphene/relay/types.py b/graphene/relay/types.py index 5b72e843..ec07b5c5 100644 --- a/graphene/relay/types.py +++ b/graphene/relay/types.py @@ -24,11 +24,6 @@ class PageInfo(ObjectType): class Edge(ObjectType): '''An edge in a connection.''' - class Meta: - type_name = 'DefaultEdge' - - node = Field(LazyType(lambda object_type: object_type.node_type), - description='The item at the end of the edge') cursor = String( required=True, description='A cursor for use in pagination') @@ -37,10 +32,11 @@ class Edge(ObjectType): def for_node(cls, node): from graphene.relay.utils import is_node assert is_node(node), 'ObjectTypes in a edge have to be Nodes' + node_field = Field(node, description='The item at the end of the edge') return type( '%s%s' % (node._meta.type_name, cls._meta.type_name), (cls,), - {'node_type': node}) + {'node_type': node, 'node': node_field}) class Connection(ObjectType): @@ -50,8 +46,6 @@ class Connection(ObjectType): page_info = Field(PageInfo, required=True, description='The Information to aid in pagination') - edges = List(LazyType(lambda object_type: object_type.edge_type), - description='Information to aid in pagination.') _connection_data = None @@ -59,12 +53,13 @@ class Connection(ObjectType): @memoize def for_node(cls, node, edge_type=None): from graphene.relay.utils import is_node - edge_type = edge_type or Edge + edge_type = edge_type or Edge.for_node(node) assert is_node(node), 'ObjectTypes in a connection have to be Nodes' + edges = List(edge_type, description='Information to aid in pagination.') return type( '%s%s' % (node._meta.type_name, cls._meta.type_name), (cls,), - {'edge_type': edge_type.for_node(node)}) + {'edge_type': edge_type, 'edges': edges}) def set_connection_data(self, data): self._connection_data = data