Improved query support

This commit is contained in:
Syrus Akbary 2016-01-22 19:21:50 -08:00
parent 017f6ae2a1
commit 6f7e00af95
5 changed files with 129 additions and 6 deletions

View File

@ -1,11 +1,12 @@
import pytest
import graphene
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType
from graphene import relay
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from .models import Base, Reporter
from .models import Base, Reporter, Article
db = create_engine('sqlite:///test_sqlalchemy.sqlite3')
@ -33,6 +34,8 @@ def setup_fixtures(session):
session.add(reporter)
reporter2 = Reporter(first_name='ABO', last_name='Y')
session.add(reporter2)
article = Article(headline='Hi!')
session.add(article)
session.commit()
@ -82,3 +85,89 @@ def test_should_query_well(session):
result = schema.execute(query)
assert not result.errors
assert result.data == expected
def test_should_node(session):
setup_fixtures(session)
class ReporterNode(SQLAlchemyNode):
class Meta:
model = Reporter
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
def resolve_articles(self, *args, **kwargs):
return [Article(headline='Hi!')]
class ArticleNode(SQLAlchemyNode):
class Meta:
model = Article
# @classmethod
# def get_node(cls, id, info):
# return Article(id=1, headline='Article node')
class Query(graphene.ObjectType):
node = relay.NodeField()
reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode)
def resolve_reporter(self, *args, **kwargs):
return Reporter(id=1, first_name='ABA', last_name='X')
def resolve_article(self, *args, **kwargs):
return Article(id=1, headline='Article node')
query = '''
query ReporterQuery {
reporter {
id,
firstName,
articles {
edges {
node {
headline
}
}
}
lastName,
email
}
myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") {
id
... on ReporterNode {
firstName
}
... on ArticleNode {
headline
}
}
}
'''
expected = {
'reporter': {
'id': 'UmVwb3J0ZXJOb2RlOjE=',
'firstName': 'ABA',
'lastName': 'X',
'email': None,
'articles': {
'edges': [{
'node': {
'headline': 'Hi!'
}
}]
},
},
'myArticle': {
'id': 'QXJ0aWNsZU5vZGU6MQ==',
'headline': 'Hi!'
}
}
schema = graphene.Schema(query=Query, session=session)
result = schema.execute(query)
assert not result.errors
assert result.data == expected

View File

@ -0,0 +1,25 @@
from graphene import Schema, ObjectType, String
from ..utils import get_session
def test_get_session():
session = 'My SQLAlchemy session'
schema = Schema(session=session)
class Query(ObjectType):
x = String()
def resolve_x(self, args, info):
return get_session(info)
query = '''
query ReporterQuery {
x
}
'''
schema = Schema(query=Query, session=session)
result = schema.execute(query)
assert not result.errors
assert result.data['x'] == session

View File

@ -3,13 +3,14 @@ import inspect
import six
from sqlalchemy.inspection import inspect as sqlalchemyinspect
from sqlalchemy.orm.exc import NoResultFound
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
from ...relay.types import Connection, Node, NodeMeta
from .converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship)
from .options import SQLAlchemyOptions
from .utils import is_mapped
from .utils import is_mapped, get_session
class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
@ -116,7 +117,9 @@ class SQLAlchemyNode(six.with_metaclass(
@classmethod
def get_node(cls, id, info=None):
try:
instance = cls._meta.model.filter(id=id).one()
model = cls._meta.model
session = get_session(info)
instance = session.query(model).filter(model.id == id).one()
return cls(instance)
except cls._meta.model.DoesNotExist:
except NoResultFound:
return None

View File

@ -15,6 +15,11 @@ def get_type_for_model(schema, model):
return _type
def get_session(info):
schema = info.schema.graphene_schema
return schema.options.get('session')
def is_mapped(obj):
return isinstance(obj, DeclarativeMeta)
# try:

View File

@ -26,7 +26,7 @@ class Schema(object):
_executor = None
def __init__(self, query=None, mutation=None, subscription=None,
name='Schema', executor=None, plugins=None, auto_camelcase=True):
name='Schema', executor=None, plugins=None, auto_camelcase=True, **options):
self._types_names = {}
self._types = {}
self.mutation = mutation
@ -38,6 +38,7 @@ class Schema(object):
if auto_camelcase:
plugins.append(CamelCase())
self.plugins = PluginManager(self, plugins)
self.options = options
signals.init_schema.send(self)
def __repr__(self):