diff --git a/graphene/contrib/sqlalchemy/fields.py b/graphene/contrib/sqlalchemy/fields.py index 99d045c4..b5f4e974 100644 --- a/graphene/contrib/sqlalchemy/fields.py +++ b/graphene/contrib/sqlalchemy/fields.py @@ -7,21 +7,24 @@ from ...relay.utils import is_node from .utils import get_type_for_model, maybe_query, get_query +class DefaultQuery(object): + pass + + class SQLAlchemyConnectionField(ConnectionField): def __init__(self, *args, **kwargs): + kwargs['default'] = kwargs.pop('default', lambda: DefaultQuery) return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs) @property def model(self): return self.type._meta.model - def get_query(self, resolved_query, args, info): - return resolved_query if resolved_query is not None else get_query(self.model, info) - def from_list(self, connection_type, resolved, args, info): - query = self.get_query(resolved, args, info) - query = maybe_query(query) + if resolved is DefaultQuery: + resolved = get_query(self.model, info) + query = maybe_query(resolved) return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info) diff --git a/graphene/contrib/sqlalchemy/tests/test_query.py b/graphene/contrib/sqlalchemy/tests/test_query.py index b47398d8..611ab2f6 100644 --- a/graphene/contrib/sqlalchemy/tests/test_query.py +++ b/graphene/contrib/sqlalchemy/tests/test_query.py @@ -2,7 +2,7 @@ import pytest import graphene from graphene import relay -from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode +from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode, SQLAlchemyConnectionField from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker @@ -115,6 +115,7 @@ def test_should_node(session): node = relay.NodeField() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) + all_articles = SQLAlchemyConnectionField(ArticleNode) def resolve_reporter(self, *args, **kwargs): return Reporter(id=1, first_name='ABA', last_name='X') @@ -137,6 +138,13 @@ def test_should_node(session): lastName, email } + allArticles { + edges { + node { + headline + } + } + } myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { id ... on ReporterNode { @@ -162,6 +170,13 @@ def test_should_node(session): }] }, }, + 'allArticles': { + 'edges': [{ + 'node': { + 'headline': 'Hi!' + } + }] + }, 'myArticle': { 'id': 'QXJ0aWNsZU5vZGU6MQ==', 'headline': 'Hi!'