diff --git a/.gitignore b/.gitignore index b7554723..9f465556 100644 --- a/.gitignore +++ b/.gitignore @@ -75,6 +75,7 @@ target/ # PyCharm .idea +*.iml # Databases *.sqlite3 diff --git a/examples/starwars_relay/schema.py b/examples/starwars_relay/schema.py index 8ad1a643..32813afd 100644 --- a/examples/starwars_relay/schema.py +++ b/examples/starwars_relay/schema.py @@ -16,6 +16,8 @@ class Ship(graphene.ObjectType): def get_node(cls, id, context, info): return get_ship(id) +ShipConnection = relay.Connection.for_type(Ship) + class Faction(graphene.ObjectType): '''A faction in the Star Wars saga''' @@ -24,7 +26,7 @@ class Faction(graphene.ObjectType): interfaces = (relay.Node, ) name = graphene.String(description='The name of the faction.') - ships = relay.ConnectionField(Ship, description='The ships used by the faction.') + ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.') @resolve_only_args def resolve_ships(self, **args): diff --git a/examples/starwars_relay/tests/test_objectidentification.py b/examples/starwars_relay/tests/test_objectidentification.py index cd63dd54..126adfb7 100644 --- a/examples/starwars_relay/tests/test_objectidentification.py +++ b/examples/starwars_relay/tests/test_objectidentification.py @@ -54,13 +54,13 @@ type Ship implements Node { } type ShipConnection { - pageInfo: PageInfo! edges: [ShipEdge] + pageInfo: PageInfo! } type ShipEdge { - node: Ship cursor: String! + node: Ship } ''' diff --git a/graphene-django/graphene_django/converter.py b/graphene-django/graphene_django/converter.py index dc1eb88b..650b3cfe 100644 --- a/graphene-django/graphene_django/converter.py +++ b/graphene-django/graphene_django/converter.py @@ -1,7 +1,7 @@ from django.db import models from django.utils.encoding import force_text -from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull, Field, Dynamic +from graphene import Enum, List, ID, Boolean, Float, Int, String, NonNull, Field, Dynamic from graphene.types.json import JSONString from graphene.types.datetime import DateTime from graphene.utils.str_converters import to_const @@ -37,7 +37,7 @@ def convert_django_field_with_choices(field, registry=None): name = '{}{}'.format(meta.object_name, field.name.capitalize()) choices = list(get_choices(choices)) named_choices = [(c[0], c[1]) for c in choices] - named_choices_descriptions = {c[0]:c[2] for c in choices} + named_choices_descriptions = {c[0]: c[2] for c in choices} class EnumWithDescriptionsType(object): @property diff --git a/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py b/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py index df967dc4..5c8518ef 100644 --- a/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py +++ b/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py @@ -30,8 +30,10 @@ class Role(SQLAlchemyObjectType): class Query(graphene.ObjectType): node = relay.Node.Field() - all_employees = SQLAlchemyConnectionField(Employee) - all_roles = SQLAlchemyConnectionField(Role) + employee_connection = relay.Connection.for_type(Employee) + role_connection = relay.Connection.for_type(Role) + all_employees = SQLAlchemyConnectionField(employee_connection) + all_roles = SQLAlchemyConnectionField(role_connection) role = graphene.Field(Role) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py index 89f58177..15ef1570 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import interfaces from sqlalchemy.dialects import postgresql from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic -from graphene.relay import is_node +from graphene.relay import is_node, Connection from graphene.types.json import JSONString from .fields import SQLAlchemyConnectionField @@ -18,9 +18,10 @@ except ImportError: pass -def convert_sqlalchemy_relationship(relationship, registry): +def convert_sqlalchemy_relationship(relationship, registry, connections, type_name): direction = relationship.direction model = relationship.mapper.entity + print(registry) def dynamic_type(): _type = registry.get_type_for_model(model) @@ -31,7 +32,13 @@ def convert_sqlalchemy_relationship(relationship, registry): elif (direction == interfaces.ONETOMANY or direction == interfaces.MANYTOMANY): if is_node(_type): - return SQLAlchemyConnectionField(_type) + try: + connection_type = connections[relationship.key] + except KeyError: + print(_type) + raise KeyError("No Connection provided for relationship {} on type {}. Specify it in its Meta " + "class on the 'connections' dict.".format(relationship.key, type_name)) + return SQLAlchemyConnectionField(connection_type) return Field(List(_type)) return Dynamic(dynamic_type) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py index 192f69e9..e2fcdf72 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py @@ -6,7 +6,7 @@ from sqlalchemy_utils import ChoiceType, ScalarListType from sqlalchemy.dialects import postgresql import graphene -from graphene.relay import Node +from graphene.relay import Node, Connection from graphene.types.json import JSONString from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -129,7 +129,7 @@ def test_should_scalar_list_convert_list(): def test_should_manytomany_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry) + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry, {}, '') assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -139,7 +139,7 @@ def test_should_manytomany_convert_connectionorlist_list(): class Meta: model = Pet - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -153,14 +153,33 @@ def test_should_manytomany_convert_connectionorlist_connection(): model = Pet interfaces = (Node, ) - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + connections = {'pets': Connection.for_type(A)} + + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A') assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField) + assert issubclass(dynamic_field.get_type().type, Connection) + + +def test_should_rais_when_no_connections_is_provided_for_manyto_many(): + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node, ) + + connections = {} + + with raises(KeyError) as ctx: + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A') + dynamic_field.get_type() + + assert str(ctx.value) == ('\"No Connection provided for relationship pets on type A. Specify it in its Meta ' + 'class on the \'connections\' dict.\"') def test_should_manytoone_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry, {}, '') assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -170,7 +189,7 @@ def test_should_manytoone_convert_connectionorlist_list(): class Meta: model = Reporter - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -183,7 +202,7 @@ def test_should_manytoone_convert_connectionorlist_connection(): model = Reporter interfaces = (Node, ) - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -196,7 +215,7 @@ def test_should_onetoone_convert_field(): model = Article interfaces = (Node, ) - dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py index 8fa8d18e..32e84adb 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py @@ -3,11 +3,11 @@ from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker import graphene -from graphene.relay import Node +from graphene.relay import Node, Connection from ..types import SQLAlchemyObjectType from ..fields import SQLAlchemyConnectionField -from .models import Article, Base, Editor, Reporter +from .models import Article, Base, Editor, Reporter, Pet db = create_engine('sqlite:///test_sqlalchemy.sqlite3') @@ -46,10 +46,34 @@ def setup_fixtures(session): def test_should_query_well(session): setup_fixtures(session) + class ArticleType(SQLAlchemyObjectType): + + class Meta: + model = Article + + ArticleTypeConnection = Connection.for_type(ArticleType) + + ReporterNodeConnection = graphene.Dynamic(lambda: Connection.for_type(ReporterType)) + + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + connections = { + 'reporters': ReporterNodeConnection, + 'articles': ArticleTypeConnection, + } + interfaces = (Node, ) + + AConnection = Connection.for_type(A) + class ReporterType(SQLAlchemyObjectType): class Meta: model = Reporter + connections = { + 'pets': AConnection, + 'articles': AConnection, + } class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) @@ -94,16 +118,6 @@ def test_should_query_well(session): def test_should_node(session): setup_fixtures(session) - class ReporterNode(SQLAlchemyObjectType): - - class Meta: - model = Reporter - interfaces = (Node, ) - - @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name='Cookie Monster') - class ArticleNode(SQLAlchemyObjectType): class Meta: @@ -114,11 +128,39 @@ def test_should_node(session): # def get_node(cls, id, info): # return Article(id=1, headline='Article node') + ArticleNodeConnection = Connection.for_type(ArticleNode) + + ReporterNodeConnection = graphene.Dynamic(lambda: Connection.for_type(ReporterNode)) + + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + connections = { + 'reporters': ReporterNodeConnection + } + interfaces = (Node, ) + + AConnection = Connection.for_type(A) + + class ReporterNode(SQLAlchemyObjectType): + + class Meta: + model = Reporter + connections = { + 'articles': ArticleNodeConnection, + 'pets': AConnection, + } + interfaces = (Node, ) + + @classmethod + def get_node(cls, id, info): + return Reporter(id=2, first_name='Cookie Monster') + class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) - all_articles = SQLAlchemyConnectionField(ArticleNode) + all_articles = SQLAlchemyConnectionField(ArticleNodeConnection) def resolve_reporter(self, *args, **kwargs): return session.query(Reporter).first() @@ -202,7 +244,8 @@ def test_should_custom_identifier(session): class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorNode) + EditorNodeConnection = Connection.for_type(EditorNode) + all_editors = SQLAlchemyConnectionField(EditorNodeConnection) query = ''' query EditorQuery { @@ -250,23 +293,40 @@ def test_should_mutate_well(session): model = Editor interfaces = (Node, ) - - class ReporterNode(SQLAlchemyObjectType): - - class Meta: - model = Reporter - interfaces = (Node, ) - - @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name='Cookie Monster') - class ArticleNode(SQLAlchemyObjectType): class Meta: model = Article interfaces = (Node, ) + ArticleNodeConnection = Connection.for_type(ArticleNode) + + ReporterNodeConnection = graphene.Dynamic(lambda: Connection.for_type(ReporterNode)) + + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + connections = { + 'reporters': ReporterNodeConnection + } + interfaces = (Node, ) + + AConnection = Connection.for_type(A) + + class ReporterNode(SQLAlchemyObjectType): + + class Meta: + model = Reporter + connections = { + 'articles': ArticleNodeConnection, + 'pets': AConnection, + } + interfaces = (Node, ) + + @classmethod + def get_node(cls, id, info): + return Reporter(id=2, first_name='Cookie Monster') + class CreateArticle(graphene.Mutation): class Input: headline = graphene.String() @@ -279,7 +339,7 @@ def test_should_mutate_well(session): def mutate(cls, instance, args, context, info): new_article = Article( headline=args.get('headline'), - reporter_id = args.get('reporter_id'), + reporter_id=args.get('reporter_id'), ) session.add(new_article) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/types.py b/graphene-sqlalchemy/graphene_sqlalchemy/types.py index bade191f..e9057d82 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/types.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/types.py @@ -18,12 +18,13 @@ from graphene.types.utils import yank_fields_from_attrs, merge from .utils import get_query -def construct_fields(options): +def construct_fields(options, type_name): only_fields = options.only_fields exclude_fields = options.exclude_fields inspected_model = sqlalchemyinspect(options.model) fields = OrderedDict() + print('options in construct_fields', options) for name, column in inspected_model.columns.items(): is_not_in_only = only_fields and name not in only_fields @@ -56,7 +57,7 @@ def construct_fields(options): # We skip this field if we specify only_fields and is not # in there. Or when we excldue this field in exclude_fields continue - converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) + converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry, options.connections, type_name) name = relationship.key fields[name] = converted_relationship @@ -82,7 +83,8 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): exclude_fields=(), id='id', interfaces=(), - registry=None + registry=None, + connections={}, ) if not options.registry: @@ -96,13 +98,12 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): '{}.Meta, received "{}".' ).format(name, options.model) - cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options)) options.registry.register(cls) options.sqlalchemy_fields = yank_fields_from_attrs( - construct_fields(options), + construct_fields(options, name), _as=Field, ) options.fields = merge( diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 17d9854e..fc5c5133 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -2,17 +2,18 @@ import re from collections import Iterable, OrderedDict from functools import partial +from promise import Promise import six from graphql_relay import connection_from_list -from ..types import Boolean, Int, List, String, AbstractType +from ..types import Boolean, Int, List, String, AbstractType, Dynamic from ..types.field import Field from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.options import Options from ..utils.is_base_type import is_base_type from ..utils.props import props -from .node import Node, is_node +from .node import Node class PageInfo(ObjectType): @@ -55,47 +56,79 @@ class ConnectionMeta(ObjectTypeMeta): ) options.interfaces = () options.local_fields = OrderedDict() - - assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__) - assert issubclass(options.node, (Node, ObjectType)), ( - 'Received incompatible node "{}" for Connection {}.' - ).format(options.node, name) - base_name = re.sub('Connection$', '', name) + + if attrs.get('edges'): + edges = attrs.get('edges') + edge = edges.of_type + else: + assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__) + assert issubclass(options.node, (Node, ObjectType)), ( + 'Received incompatible node "{}" for Connection {}.' + ).format(options.node, name) + + base_name = re.sub('Connection$', '', name) + + edge_class = attrs.pop('Edge', None) + + edge_attrs = { + 'node': Field( + options.node, description='The item at the end of the edge'), + 'cursor': Edge._meta.fields['cursor'] + } + + edge_name = '{}Edge'.format(base_name) + if edge_class and issubclass(edge_class, AbstractType): + edge = type(edge_name, (edge_class, ObjectType, ), edge_attrs) + else: + additional_attrs = props(edge_class) if edge_class else {} + edge_attrs.update(additional_attrs) + edge = type(edge_name, (ObjectType, ), edge_attrs) + + edges = List(edge) + if not options.name: options.name = '{}Connection'.format(base_name) - edge_class = attrs.pop('Edge', None) - - class EdgeBase(AbstractType): - node = Field(options.node, description='The item at the end of the edge') - cursor = String(required=True, description='A cursor for use in pagination') - - edge_name = '{}Edge'.format(base_name) - if edge_class and issubclass(edge_class, AbstractType): - edge = type(edge_name, (EdgeBase, edge_class, ObjectType, ), {}) - else: - edge_attrs = props(edge_class) if edge_class else {} - edge = type(edge_name, (EdgeBase, ObjectType, ), edge_attrs) - - class ConnectionBase(AbstractType): - page_info = Field(PageInfo, name='pageInfo', required=True) - edges = List(edge) - - bases = (ConnectionBase, ) + bases + attrs.update({ + 'page_info': Field(PageInfo, name='pageInfo', required=True), + 'edges': edges, + }) attrs = dict(attrs, _meta=options, Edge=edge) return ObjectTypeMeta.__new__(cls, name, bases, attrs) class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): - pass + + @classmethod + def for_type(cls, gql_type): + connection_name = '{}Connection'.format(gql_type._meta.name) + + class Meta(object): + node = gql_type + + return type(connection_name, (Connection, ), {'Meta': Meta}) + + +class Edge(AbstractType): + cursor = String(required=True, description='A cursor for use in pagination') + + +def is_connection(gql_type): + '''Checks if a type is a connection. Taken directly from the spec definition: + https://facebook.github.io/relay/graphql/connections.htm#sec-Connection-Types''' + return gql_type._meta.name.endswith('Connection') if hasattr(gql_type, '_meta') else False class IterableConnectionField(Field): - def __init__(self, type, *args, **kwargs): + def __init__(self, gql_type, *args, **kwargs): + assert is_connection(gql_type) or isinstance(gql_type, Dynamic), ( + 'The provided type "{}" for this ConnectionField has to be a Connection as defined by the Relay' + ' spec.'.format(gql_type) + ) super(IterableConnectionField, self).__init__( - type, + gql_type, *args, before=String(), after=String(), @@ -106,32 +139,39 @@ class IterableConnectionField(Field): @property def type(self): - type = super(IterableConnectionField, self).type - if is_node(type): - connection_type = type.Connection + gql_type = super(IterableConnectionField, self).type + if isinstance(gql_type, Dynamic): + return gql_type.get_type() else: - connection_type = type - assert issubclass(connection_type, Connection), ( - '{} type have to be a subclass of Connection. Received "{}".' - ).format(str(self), connection_type) - return connection_type + return gql_type @staticmethod def connection_resolver(resolver, connection, root, args, context, info): - iterable = resolver(root, args, context, info) - assert isinstance(iterable, Iterable), ( - 'Resolved value from the connection field have to be iterable. ' - 'Received "{}"' - ).format(iterable) - connection = connection_from_list( - iterable, - args, - connection_type=connection, - edge_type=connection.Edge, - pageinfo_type=PageInfo - ) - connection.iterable = iterable - return connection + resolved = Promise.resolve(resolver(root, args, context, info)) + + def handle_connection_and_list(result): + if isinstance(result, connection): + return result + elif is_connection(result): + raise AssertionError('Resolved value from the connection field has to be a {}. ' + 'Received {}.'.format(connection, type(result))) + else: + assert isinstance(result, Iterable), ( + 'Resolved value from the connection field have to be iterable. ' + 'Received "{}"' + ).format(result) + + resolved_connection = connection_from_list( + result, + args, + connection_type=connection, + edge_type=connection.Edge, + pageinfo_type=PageInfo + ) + resolved_connection.iterable = result + return resolved_connection + + return resolved.then(handle_connection_and_list) def get_resolver(self, parent_resolver): resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) diff --git a/graphene/relay/node.py b/graphene/relay/node.py index f73f6a53..4b20f5b5 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -21,18 +21,6 @@ def is_node(objecttype): return False -def get_default_connection(cls): - from .connection import Connection - assert issubclass(cls, ObjectType), ( - 'Can only get connection type on implemented Nodes.' - ) - - class Meta: - node = cls - - return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta}) - - class GlobalID(Field): def __init__(self, node, *args, **kwargs): @@ -100,11 +88,3 @@ class Node(six.with_metaclass(NodeMeta, Interface)): @classmethod def to_global_id(cls, type, id): return to_global_id(type, id) - - @classmethod - def implements(cls, objecttype): - get_connection = getattr(objecttype, 'get_connection', None) - if not get_connection: - get_connection = partial(get_default_connection, objecttype) - - objecttype.Connection = get_connection() diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index 2f1441ac..8e7ec608 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -1,6 +1,5 @@ - from ...types import Field, List, NonNull, ObjectType, String, AbstractType -from ..connection import Connection, PageInfo +from ..connection import Connection, PageInfo, Edge from ..node import Node @@ -23,7 +22,52 @@ def test_connection(): assert MyObjectConnection._meta.name == 'MyObjectConnection' fields = MyObjectConnection._meta.fields - assert list(fields.keys()) == ['page_info', 'edges', 'extra'] + assert list(fields.keys()) == ['extra', 'edges', 'page_info'] + edge_field = fields['edges'] + pageinfo_field = fields['page_info'] + + assert isinstance(edge_field, Field) + assert isinstance(edge_field.type, List) + assert edge_field.type.of_type == MyObjectConnection.Edge + + assert isinstance(pageinfo_field, Field) + assert isinstance(pageinfo_field.type, NonNull) + assert pageinfo_field.type.of_type == PageInfo + + +def test_multiple_connection_edges_are_not_the_same(): + class MyObjectConnection(Connection): + extra = String() + + class Meta: + node = MyObject + + class Edge: + other = String() + + class MyOtherObjectConnection(Connection): + class Meta: + node = MyObject + + class Edge: + other = String() + + assert MyObjectConnection.Edge != MyOtherObjectConnection.Edge + assert MyObjectConnection.Edge._meta.name != MyOtherObjectConnection.Edge._meta.name + + +def test_create_connection_with_custom_edge_type(): + class MyEdge(Edge): + node = Field(MyObject) + + class MyObjectConnection(Connection): + extra = String() + edges = List(MyEdge) + + assert MyObjectConnection.Edge == MyEdge + assert MyObjectConnection._meta.name == 'MyObjectConnection' + fields = MyObjectConnection._meta.fields + assert list(fields.keys()) == ['extra', 'edges', 'page_info'] edge_field = fields['edges'] pageinfo_field = fields['page_info'] @@ -46,7 +90,18 @@ def test_connection_inherit_abstracttype(): assert MyObjectConnection._meta.name == 'MyObjectConnection' fields = MyObjectConnection._meta.fields - assert list(fields.keys()) == ['page_info', 'edges', 'extra'] + assert list(fields.keys()) == ['extra', 'edges', 'page_info'] + + +def test_defaul_connection_for_type(): + MyObjectConnection = Connection.for_type(MyObject) + assert MyObjectConnection._meta.name == 'MyObjectConnection' + fields = MyObjectConnection._meta.fields + assert list(fields.keys()) == ['edges', 'page_info'] + + +def test_default_connection_for_type_does_not_returns_same_Connection(): + assert Connection.for_type(MyObject) != Connection.for_type(MyObject) def test_edge(): @@ -60,7 +115,7 @@ def test_edge(): Edge = MyObjectConnection.Edge assert Edge._meta.name == 'MyObjectEdge' edge_fields = Edge._meta.fields - assert list(edge_fields.keys()) == ['node', 'cursor', 'other'] + assert list(edge_fields.keys()) == ['cursor', 'other', 'node'] assert isinstance(edge_fields['node'], Field) assert edge_fields['node'].type == MyObject @@ -83,7 +138,7 @@ def test_edge_with_bases(): Edge = MyObjectConnection.Edge assert Edge._meta.name == 'MyObjectEdge' edge_fields = Edge._meta.fields - assert list(edge_fields.keys()) == ['node', 'cursor', 'extra', 'other'] + assert list(edge_fields.keys()) == ['extra', 'other', 'cursor', 'node'] assert isinstance(edge_fields['node'], Field) assert edge_fields['node'].type == MyObject @@ -92,17 +147,37 @@ def test_edge_with_bases(): assert edge_fields['other'].type == String -def test_edge_on_node(): - Edge = MyObject.Connection.Edge - assert Edge._meta.name == 'MyObjectEdge' - edge_fields = Edge._meta.fields - assert list(edge_fields.keys()) == ['node', 'cursor'] +def test_pageinfo(): + assert PageInfo._meta.name == 'PageInfo' + fields = PageInfo._meta.fields + assert list(fields.keys()) == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor'] + + +def test_edge_for_node_type(): + edge = Connection.for_type(MyObject).Edge + + assert edge._meta.name == 'MyObjectEdge' + edge_fields = edge._meta.fields + assert list(edge_fields.keys()) == ['cursor', 'node'] assert isinstance(edge_fields['node'], Field) assert edge_fields['node'].type == MyObject -def test_pageinfo(): - assert PageInfo._meta.name == 'PageInfo' - fields = PageInfo._meta.fields - assert list(fields.keys()) == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor'] +def test_edge_for_object_type(): + class MyObject(ObjectType): + field = String() + + edge = Connection.for_type(MyObject).Edge + + assert edge._meta.name == 'MyObjectEdge' + edge_fields = edge._meta.fields + assert list(edge_fields.keys()) == ['cursor', 'node'] + + assert isinstance(edge_fields['node'], Field) + assert edge_fields['node'].type == MyObject + + +def test_edge_for_type_returns_same_edge(): + MyObjectConnection = Connection.for_type(MyObject) + assert MyObjectConnection.Edge == MyObjectConnection.Edge diff --git a/graphene/relay/tests/test_connection_query.py b/graphene/relay/tests/test_connection_query.py index b2c10c31..47be34ee 100644 --- a/graphene/relay/tests/test_connection_query.py +++ b/graphene/relay/tests/test_connection_query.py @@ -1,6 +1,8 @@ from collections import OrderedDict -from ..connection import ConnectionField +from promise import Promise + +from ..connection import ConnectionField, Connection, PageInfo from ..node import Node from graphql_relay.utils import base64 from ...types import ObjectType, String, Schema @@ -15,12 +17,40 @@ class Letter(ObjectType): letter = String() -class Query(ObjectType): - letters = ConnectionField(Letter) +class MyLetterObjectConnection(Connection): + extra = String() - def resolve_letters(self, args, context, info): + class Meta: + node = Letter + + class Edge: + other = String() + +LetterConnection = Connection.for_type(Letter) + + +class Query(ObjectType): + letters = ConnectionField(LetterConnection) + letters_wrong_connection = ConnectionField(LetterConnection) + letters_promise = ConnectionField(LetterConnection) + letters_connection = ConnectionField(MyLetterObjectConnection) + + def resolve_letters(self, *_): return list(letters.values()) + def resolve_letters_wrong_connection(self, *_): + return MyLetterObjectConnection() + + def resolve_letters_connection(self, *_): + return MyLetterObjectConnection( + extra='1', + page_info=PageInfo(has_next_page=True, has_previous_page=False), + edges=[MyLetterObjectConnection.Edge(cursor='1', node=Letter(letter='hello'))] + ) + + def resolve_letters_promise(self, *_): + return Promise.resolve(list(letters.values())) + node = Node.Field() @@ -75,8 +105,7 @@ def execute(args=''): ''' % args) -def check(args, letters, has_previous_page=False, has_next_page=False): - result = execute(args) +def create_expexted_result(letters, has_previous_page=False, has_next_page=False, field_name='letters'): expected_edges = edges(letters) expected_page_info = { 'hasPreviousPage': has_previous_page, @@ -84,16 +113,107 @@ def check(args, letters, has_previous_page=False, has_next_page=False): 'endCursor': expected_edges[-1]['cursor'] if expected_edges else None, 'startCursor': expected_edges[0]['cursor'] if expected_edges else None } - - assert not result.errors - assert result.data == { - 'letters': { + return { + field_name: { 'edges': expected_edges, 'pageInfo': expected_page_info } } +def check(args, letters, has_previous_page=False, has_next_page=False): + result = execute(args) + assert not result.errors + assert result.data == create_expexted_result(letters, has_previous_page, has_next_page) + + +def test_resolver_throws_error_on_returning_wrong_connection_type(): + result = schema.execute(''' + { + lettersWrongConnection { + edges { + node { + id + } + } + } + } + ''') + assert result.errors[0].message == ('Resolved value from the connection field has to be a LetterConnection. ' + 'Received MyLetterObjectConnection.') + + +def test_resolver_handles_returned_connection_field_correctly(): + result = schema.execute(''' + { + lettersConnection { + extra + edges { + node { + id + letter + } + cursor + } + pageInfo { + hasPreviousPage + hasNextPage + startCursor + endCursor + } + } + } + ''') + + assert not result.errors + expected_result = { + 'lettersConnection': { + 'extra': '1', + 'edges': [ + { + 'node': { + 'id': 'TGV0dGVyOk5vbmU=', + 'letter': 'hello', + }, + 'cursor': '1' + } + ], + 'pageInfo': { + 'hasPreviousPage': False, + 'hasNextPage': True, + 'startCursor': None, + 'endCursor': None, + } + } + } + assert result.data == expected_result + + +def test_resolver_handles_returned_promise_correctly(): + result = schema.execute(''' + { + lettersPromise { + edges { + node { + id + letter + } + cursor + } + pageInfo { + hasPreviousPage + hasNextPage + startCursor + endCursor + } + } + } + ''') + + assert not result.errors + assert result.data == create_expexted_result('ABCDE', field_name='lettersPromise') + + def test_returns_all_elements_without_filters(): check('', 'ABCDE') diff --git a/graphene/relay/tests/test_mutation.py b/graphene/relay/tests/test_mutation.py index f8502df0..396770a2 100644 --- a/graphene/relay/tests/test_mutation.py +++ b/graphene/relay/tests/test_mutation.py @@ -1,4 +1,3 @@ -from collections import OrderedDict import pytest from ...types import (Argument, Field, InputField, InputObjectType, ObjectType, @@ -21,6 +20,9 @@ class MyNode(ObjectType): name = String() +MyNodeConnection = Connection.for_type(MyNode) + + class SaySomething(ClientIDMutation): class Input: @@ -39,13 +41,13 @@ class OtherMutation(ClientIDMutation): additional_field = String() name = String() - my_node_edge = Field(MyNode.Connection.Edge) + my_node_edge = Field(MyNodeConnection.Edge) @classmethod def mutate_and_get_payload(cls, args, context, info): shared = args.get('shared', '') additionalField = args.get('additionalField', '') - edge_type = MyNode.Connection.Edge + edge_type = MyNodeConnection.Edge return OtherMutation(name=shared + additionalField, my_node_edge=edge_type( cursor='1', node=MyNode(name='name'))) diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index ef35f409..92963440 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -53,15 +53,6 @@ def test_node_good(): assert 'id' in MyNode._meta.fields -def test_node_get_connection(): - connection = MyNode.Connection - assert issubclass(connection, Connection) - - -def test_node_get_connection_dont_duplicate(): - assert MyNode.Connection == MyNode.Connection - - def test_node_query(): executed = schema.execute( '{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1) diff --git a/graphene/types/field.py b/graphene/types/field.py index b1122696..981f7934 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -16,7 +16,7 @@ def source_resolver(source, root, args, context, info): class Field(OrderedType): - def __init__(self, type, args=None, resolver=None, source=None, + def __init__(self, gql_type, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=None, **extra_args): super(Field, self).__init__(_creation_counter=_creation_counter) @@ -28,10 +28,10 @@ class Field(OrderedType): ) if required: - type = NonNull(type) + gql_type = NonNull(gql_type) self.name = name - self._type = type + self._type = gql_type self.args = to_arguments(args or OrderedDict(), extra_args) if source: resolver = partial(source_resolver, source) diff --git a/graphene/types/interface.py b/graphene/types/interface.py index a67ce914..f4b7585a 100644 --- a/graphene/types/interface.py +++ b/graphene/types/interface.py @@ -52,7 +52,3 @@ class Interface(six.with_metaclass(InterfaceMeta)): def __init__(self, *args, **kwargs): raise Exception("An Interface cannot be intitialized") - - @classmethod - def implements(cls, objecttype): - pass diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 91b74c0f..9e8802fa 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -46,9 +46,6 @@ class ObjectTypeMeta(AbstractTypeMeta): cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) - for interface in options.interfaces: - interface.implements(cls) - return cls def __str__(cls): # noqa: N802 diff --git a/setup.py b/setup.py index dfd82742..b11a90bc 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ else: # machinery. builtins.__SETUP__ = True -version = __import__('graphene').get_version() +version = "1.0.beta-1" class PyTest(TestCommand): diff --git a/tox.ini b/tox.ini index 7dbcffa5..d18983aa 100644 --- a/tox.ini +++ b/tox.ini @@ -5,8 +5,8 @@ skipsdist = true [testenv] deps= pytest>=2.7.2 - graphql-core>=0.5.1 - graphql-relay>=0.4.3 + graphql-core>=1.0.dev + graphql-relay>=0.4.4 six blinker singledispatch