diff --git a/examples/starwars_relay/schema.py b/examples/starwars_relay/schema.py index 8ad1a643..32813afd 100644 --- a/examples/starwars_relay/schema.py +++ b/examples/starwars_relay/schema.py @@ -16,6 +16,8 @@ class Ship(graphene.ObjectType): def get_node(cls, id, context, info): return get_ship(id) +ShipConnection = relay.Connection.for_type(Ship) + class Faction(graphene.ObjectType): '''A faction in the Star Wars saga''' @@ -24,7 +26,7 @@ class Faction(graphene.ObjectType): interfaces = (relay.Node, ) name = graphene.String(description='The name of the faction.') - ships = relay.ConnectionField(Ship, description='The ships used by the faction.') + ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.') @resolve_only_args def resolve_ships(self, **args): diff --git a/graphene-django/graphene_django/converter.py b/graphene-django/graphene_django/converter.py index dc1eb88b..650b3cfe 100644 --- a/graphene-django/graphene_django/converter.py +++ b/graphene-django/graphene_django/converter.py @@ -1,7 +1,7 @@ from django.db import models from django.utils.encoding import force_text -from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull, Field, Dynamic +from graphene import Enum, List, ID, Boolean, Float, Int, String, NonNull, Field, Dynamic from graphene.types.json import JSONString from graphene.types.datetime import DateTime from graphene.utils.str_converters import to_const @@ -37,7 +37,7 @@ def convert_django_field_with_choices(field, registry=None): name = '{}{}'.format(meta.object_name, field.name.capitalize()) choices = list(get_choices(choices)) named_choices = [(c[0], c[1]) for c in choices] - named_choices_descriptions = {c[0]:c[2] for c in choices} + named_choices_descriptions = {c[0]: c[2] for c in choices} class EnumWithDescriptionsType(object): @property diff --git a/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py b/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py index df967dc4..5c8518ef 100644 --- a/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py +++ b/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py @@ -30,8 +30,10 @@ class Role(SQLAlchemyObjectType): class Query(graphene.ObjectType): node = relay.Node.Field() - all_employees = SQLAlchemyConnectionField(Employee) - all_roles = SQLAlchemyConnectionField(Role) + employee_connection = relay.Connection.for_type(Employee) + role_connection = relay.Connection.for_type(Role) + all_employees = SQLAlchemyConnectionField(employee_connection) + all_roles = SQLAlchemyConnectionField(role_connection) role = graphene.Field(Role) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py index 89f58177..6839a26f 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import interfaces from sqlalchemy.dialects import postgresql from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic -from graphene.relay import is_node +from graphene.relay import is_node, Connection from graphene.types.json import JSONString from .fields import SQLAlchemyConnectionField @@ -18,7 +18,7 @@ except ImportError: pass -def convert_sqlalchemy_relationship(relationship, registry): +def convert_sqlalchemy_relationship(relationship, registry, connections, type_name): direction = relationship.direction model = relationship.mapper.entity @@ -31,7 +31,12 @@ def convert_sqlalchemy_relationship(relationship, registry): elif (direction == interfaces.ONETOMANY or direction == interfaces.MANYTOMANY): if is_node(_type): - return SQLAlchemyConnectionField(_type) + try: + connection_type = connections[relationship.key] + except KeyError: + raise KeyError("No Connection provided for relationship {} on type {}. Specify it in its Meta " + "class on the 'connections' dict.".format(relationship.key, type_name)) + return SQLAlchemyConnectionField(connection_type) return Field(List(_type)) return Dynamic(dynamic_type) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py index 192f69e9..e2fcdf72 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py @@ -6,7 +6,7 @@ from sqlalchemy_utils import ChoiceType, ScalarListType from sqlalchemy.dialects import postgresql import graphene -from graphene.relay import Node +from graphene.relay import Node, Connection from graphene.types.json import JSONString from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_composite, @@ -129,7 +129,7 @@ def test_should_scalar_list_convert_list(): def test_should_manytomany_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry) + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry, {}, '') assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -139,7 +139,7 @@ def test_should_manytomany_convert_connectionorlist_list(): class Meta: model = Pet - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -153,14 +153,33 @@ def test_should_manytomany_convert_connectionorlist_connection(): model = Pet interfaces = (Node, ) - dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + connections = {'pets': Connection.for_type(A)} + + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A') assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField) + assert issubclass(dynamic_field.get_type().type, Connection) + + +def test_should_rais_when_no_connections_is_provided_for_manyto_many(): + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + interfaces = (Node, ) + + connections = {} + + with raises(KeyError) as ctx: + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A') + dynamic_field.get_type() + + assert str(ctx.value) == ('\"No Connection provided for relationship pets on type A. Specify it in its Meta ' + 'class on the \'connections\' dict.\"') def test_should_manytoone_convert_connectionorlist(): registry = Registry() - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry, {}, '') assert isinstance(dynamic_field, graphene.Dynamic) assert not dynamic_field.get_type() @@ -170,7 +189,7 @@ def test_should_manytoone_convert_connectionorlist_list(): class Meta: model = Reporter - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -183,7 +202,7 @@ def test_should_manytoone_convert_connectionorlist_connection(): model = Reporter interfaces = (Node, ) - dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) @@ -196,7 +215,7 @@ def test_should_onetoone_convert_field(): model = Article interfaces = (Node, ) - dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry, {}, 'A') assert isinstance(dynamic_field, graphene.Dynamic) graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py index 8fa8d18e..a6caf135 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py @@ -3,7 +3,7 @@ from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker import graphene -from graphene.relay import Node +from graphene.relay import Node, Connection from ..types import SQLAlchemyObjectType from ..fields import SQLAlchemyConnectionField @@ -94,16 +94,6 @@ def test_should_query_well(session): def test_should_node(session): setup_fixtures(session) - class ReporterNode(SQLAlchemyObjectType): - - class Meta: - model = Reporter - interfaces = (Node, ) - - @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name='Cookie Monster') - class ArticleNode(SQLAlchemyObjectType): class Meta: @@ -114,11 +104,26 @@ def test_should_node(session): # def get_node(cls, id, info): # return Article(id=1, headline='Article node') + ArticleNodeConnection = Connection.for_type(ArticleNode) + + class ReporterNode(SQLAlchemyObjectType): + + class Meta: + model = Reporter + connections = { + 'articles': ArticleNodeConnection + } + interfaces = (Node, ) + + @classmethod + def get_node(cls, id, info): + return Reporter(id=2, first_name='Cookie Monster') + class Query(graphene.ObjectType): node = Node.Field() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) - all_articles = SQLAlchemyConnectionField(ArticleNode) + all_articles = SQLAlchemyConnectionField(ArticleNodeConnection) def resolve_reporter(self, *args, **kwargs): return session.query(Reporter).first() @@ -202,7 +207,8 @@ def test_should_custom_identifier(session): class Query(graphene.ObjectType): node = Node.Field() - all_editors = SQLAlchemyConnectionField(EditorNode) + EditorNodeConnection = Connection.for_type(EditorNode) + all_editors = SQLAlchemyConnectionField(EditorNodeConnection) query = ''' query EditorQuery { @@ -250,23 +256,27 @@ def test_should_mutate_well(session): model = Editor interfaces = (Node, ) - - class ReporterNode(SQLAlchemyObjectType): - - class Meta: - model = Reporter - interfaces = (Node, ) - - @classmethod - def get_node(cls, id, info): - return Reporter(id=2, first_name='Cookie Monster') - class ArticleNode(SQLAlchemyObjectType): class Meta: model = Article interfaces = (Node, ) + ArticleNodeConnection = Connection.for_type(ArticleNode) + + class ReporterNode(SQLAlchemyObjectType): + + class Meta: + model = Reporter + connections = { + 'articles': ArticleNodeConnection + } + interfaces = (Node, ) + + @classmethod + def get_node(cls, id, info): + return Reporter(id=2, first_name='Cookie Monster') + class CreateArticle(graphene.Mutation): class Input: headline = graphene.String() @@ -279,7 +289,7 @@ def test_should_mutate_well(session): def mutate(cls, instance, args, context, info): new_article = Article( headline=args.get('headline'), - reporter_id = args.get('reporter_id'), + reporter_id=args.get('reporter_id'), ) session.add(new_article) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/types.py b/graphene-sqlalchemy/graphene_sqlalchemy/types.py index bade191f..18a3e52c 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/types.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/types.py @@ -18,7 +18,7 @@ from graphene.types.utils import yank_fields_from_attrs, merge from .utils import get_query -def construct_fields(options): +def construct_fields(options, type_name): only_fields = options.only_fields exclude_fields = options.exclude_fields inspected_model = sqlalchemyinspect(options.model) @@ -56,7 +56,7 @@ def construct_fields(options): # We skip this field if we specify only_fields and is not # in there. Or when we excldue this field in exclude_fields continue - converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) + converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry, options.connections, type_name) name = relationship.key fields[name] = converted_relationship @@ -82,7 +82,8 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): exclude_fields=(), id='id', interfaces=(), - registry=None + registry=None, + connections={}, ) if not options.registry: @@ -96,13 +97,12 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): '{}.Meta, received "{}".' ).format(name, options.model) - cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options)) options.registry.register(cls) options.sqlalchemy_fields = yank_fields_from_attrs( - construct_fields(options), + construct_fields(options, name), _as=Field, ) options.fields = merge( diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 21b58ce9..235a4f68 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -5,7 +5,6 @@ from functools import partial from promise import Promise import six -from fastcache import clru_cache from graphql_relay import connection_from_list from ..types import Boolean, Int, List, String, AbstractType @@ -102,7 +101,6 @@ class ConnectionMeta(ObjectTypeMeta): class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): @classmethod - @clru_cache(maxsize=None) def for_type(cls, gql_type): connection_name = '{}Connection'.format(gql_type._meta.name) @@ -125,6 +123,10 @@ def is_connection(gql_type): class IterableConnectionField(Field): def __init__(self, gql_type, *args, **kwargs): + assert is_connection(gql_type), ( + 'The provided type "{}" for this ConnectionField has to be a Connection as defined by the Relay' + ' spec.'.format(gql_type) + ) super(IterableConnectionField, self).__init__( gql_type, *args, @@ -134,19 +136,17 @@ class IterableConnectionField(Field): last=Int(), **kwargs ) - self._gql_type = gql_type - - @property - def type(self): - return self._gql_type if is_connection(self._gql_type) else Connection.for_type(self._gql_type) @staticmethod def connection_resolver(resolver, connection, root, args, context, info): resolved = Promise.resolve(resolver(root, args, context, info)) def handle_connection_and_list(result): - if is_connection(result): + if isinstance(result, connection): return result + elif is_connection(result): + raise AssertionError('Resolved value from the connection field has to be a {}. ' + 'Received {}.'.format(connection, type(result))) else: assert isinstance(result, Iterable), ( 'Resolved value from the connection field have to be iterable. ' diff --git a/graphene/relay/tests/test_connection.py b/graphene/relay/tests/test_connection.py index 94f0456b..8e7ec608 100644 --- a/graphene/relay/tests/test_connection.py +++ b/graphene/relay/tests/test_connection.py @@ -100,8 +100,8 @@ def test_defaul_connection_for_type(): assert list(fields.keys()) == ['edges', 'page_info'] -def test_defaul_connection_for_type_returns_same_Connection(): - assert Connection.for_type(MyObject) == Connection.for_type(MyObject) +def test_default_connection_for_type_does_not_returns_same_Connection(): + assert Connection.for_type(MyObject) != Connection.for_type(MyObject) def test_edge(): @@ -179,4 +179,5 @@ def test_edge_for_object_type(): def test_edge_for_type_returns_same_edge(): - assert Connection.for_type(MyObject).Edge == Connection.for_type(MyObject).Edge + MyObjectConnection = Connection.for_type(MyObject) + assert MyObjectConnection.Edge == MyObjectConnection.Edge diff --git a/graphene/relay/tests/test_connection_query.py b/graphene/relay/tests/test_connection_query.py index cd42f2d8..47be34ee 100644 --- a/graphene/relay/tests/test_connection_query.py +++ b/graphene/relay/tests/test_connection_query.py @@ -26,15 +26,21 @@ class MyLetterObjectConnection(Connection): class Edge: other = String() +LetterConnection = Connection.for_type(Letter) + class Query(ObjectType): - letters = ConnectionField(Letter) - letters_promise = ConnectionField(Letter) + letters = ConnectionField(LetterConnection) + letters_wrong_connection = ConnectionField(LetterConnection) + letters_promise = ConnectionField(LetterConnection) letters_connection = ConnectionField(MyLetterObjectConnection) def resolve_letters(self, *_): return list(letters.values()) + def resolve_letters_wrong_connection(self, *_): + return MyLetterObjectConnection() + def resolve_letters_connection(self, *_): return MyLetterObjectConnection( extra='1', @@ -121,6 +127,22 @@ def check(args, letters, has_previous_page=False, has_next_page=False): assert result.data == create_expexted_result(letters, has_previous_page, has_next_page) +def test_resolver_throws_error_on_returning_wrong_connection_type(): + result = schema.execute(''' + { + lettersWrongConnection { + edges { + node { + id + } + } + } + } + ''') + assert result.errors[0].message == ('Resolved value from the connection field has to be a LetterConnection. ' + 'Received MyLetterObjectConnection.') + + def test_resolver_handles_returned_connection_field_correctly(): result = schema.execute(''' { diff --git a/graphene/relay/tests/test_mutation.py b/graphene/relay/tests/test_mutation.py index 508284ba..396770a2 100644 --- a/graphene/relay/tests/test_mutation.py +++ b/graphene/relay/tests/test_mutation.py @@ -20,6 +20,9 @@ class MyNode(ObjectType): name = String() +MyNodeConnection = Connection.for_type(MyNode) + + class SaySomething(ClientIDMutation): class Input: @@ -38,13 +41,13 @@ class OtherMutation(ClientIDMutation): additional_field = String() name = String() - my_node_edge = Field(Connection.for_type(MyNode).Edge) + my_node_edge = Field(MyNodeConnection.Edge) @classmethod def mutate_and_get_payload(cls, args, context, info): shared = args.get('shared', '') additionalField = args.get('additionalField', '') - edge_type = Connection.for_type(MyNode).Edge + edge_type = MyNodeConnection.Edge return OtherMutation(name=shared + additionalField, my_node_edge=edge_type( cursor='1', node=MyNode(name='name'))) diff --git a/graphene/types/field.py b/graphene/types/field.py index b1122696..981f7934 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -16,7 +16,7 @@ def source_resolver(source, root, args, context, info): class Field(OrderedType): - def __init__(self, type, args=None, resolver=None, source=None, + def __init__(self, gql_type, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=None, **extra_args): super(Field, self).__init__(_creation_counter=_creation_counter) @@ -28,10 +28,10 @@ class Field(OrderedType): ) if required: - type = NonNull(type) + gql_type = NonNull(gql_type) self.name = name - self._type = type + self._type = gql_type self.args = to_arguments(args or OrderedDict(), extra_args) if source: resolver = partial(source_resolver, source) diff --git a/setup.py b/setup.py index c295642f..dfd82742 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,6 @@ setup( 'six>=1.10.0', 'graphql-core>=1.0.dev', 'graphql-relay>=0.4.4', - 'fastcache>=1.0.2', 'promise', ], tests_require=[ diff --git a/tox.ini b/tox.ini index 01528c65..d18983aa 100644 --- a/tox.ini +++ b/tox.ini @@ -7,7 +7,6 @@ deps= pytest>=2.7.2 graphql-core>=1.0.dev graphql-relay>=0.4.4 - fastcache>=1.0.2 six blinker singledispatch