Fixed default query value in SQLAlchemyConnectionField. Fixed #133

This commit is contained in:
Syrus Akbary 2016-03-30 00:12:59 -07:00
parent 21ec1163b9
commit 1f548f188d
2 changed files with 24 additions and 6 deletions

View File

@ -7,21 +7,24 @@ from ...relay.utils import is_node
from .utils import get_type_for_model, maybe_query, get_query from .utils import get_type_for_model, maybe_query, get_query
class DefaultQuery(object):
pass
class SQLAlchemyConnectionField(ConnectionField): class SQLAlchemyConnectionField(ConnectionField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['default'] = kwargs.pop('default', lambda: DefaultQuery)
return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs) return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs)
@property @property
def model(self): def model(self):
return self.type._meta.model return self.type._meta.model
def get_query(self, resolved_query, args, info):
return resolved_query if resolved_query is not None else get_query(self.model, info)
def from_list(self, connection_type, resolved, args, info): def from_list(self, connection_type, resolved, args, info):
query = self.get_query(resolved, args, info) if resolved is DefaultQuery:
query = maybe_query(query) resolved = get_query(self.model, info)
query = maybe_query(resolved)
return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info) return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, info)

View File

@ -2,7 +2,7 @@ import pytest
import graphene import graphene
from graphene import relay from graphene import relay
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode, SQLAlchemyConnectionField
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
@ -115,6 +115,7 @@ def test_should_node(session):
node = relay.NodeField() node = relay.NodeField()
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
return Reporter(id=1, first_name='ABA', last_name='X') return Reporter(id=1, first_name='ABA', last_name='X')
@ -137,6 +138,13 @@ def test_should_node(session):
lastName, lastName,
email email
} }
allArticles {
edges {
node {
headline
}
}
}
myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") {
id id
... on ReporterNode { ... on ReporterNode {
@ -162,6 +170,13 @@ def test_should_node(session):
}] }]
}, },
}, },
'allArticles': {
'edges': [{
'node': {
'headline': 'Hi!'
}
}]
},
'myArticle': { 'myArticle': {
'id': 'QXJ0aWNsZU5vZGU6MQ==', 'id': 'QXJ0aWNsZU5vZGU6MQ==',
'headline': 'Hi!' 'headline': 'Hi!'