diff --git a/graphene/contrib/sqlalchemy/tests/test_query.py b/graphene/contrib/sqlalchemy/tests/test_query.py index 0384aa8b..b47398d8 100644 --- a/graphene/contrib/sqlalchemy/tests/test_query.py +++ b/graphene/contrib/sqlalchemy/tests/test_query.py @@ -1,11 +1,12 @@ import pytest import graphene -from graphene.contrib.sqlalchemy import SQLAlchemyObjectType +from graphene import relay +from graphene.contrib.sqlalchemy import SQLAlchemyObjectType, SQLAlchemyNode from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker -from .models import Base, Reporter +from .models import Base, Reporter, Article db = create_engine('sqlite:///test_sqlalchemy.sqlite3') @@ -33,6 +34,8 @@ def setup_fixtures(session): session.add(reporter) reporter2 = Reporter(first_name='ABO', last_name='Y') session.add(reporter2) + article = Article(headline='Hi!') + session.add(article) session.commit() @@ -82,3 +85,89 @@ def test_should_query_well(session): result = schema.execute(query) assert not result.errors assert result.data == expected + + +def test_should_node(session): + setup_fixtures(session) + + class ReporterNode(SQLAlchemyNode): + + class Meta: + model = Reporter + + @classmethod + def get_node(cls, id, info): + return Reporter(id=2, first_name='Cookie Monster') + + def resolve_articles(self, *args, **kwargs): + return [Article(headline='Hi!')] + + class ArticleNode(SQLAlchemyNode): + + class Meta: + model = Article + + # @classmethod + # def get_node(cls, id, info): + # return Article(id=1, headline='Article node') + + class Query(graphene.ObjectType): + node = relay.NodeField() + reporter = graphene.Field(ReporterNode) + article = graphene.Field(ArticleNode) + + def resolve_reporter(self, *args, **kwargs): + return Reporter(id=1, first_name='ABA', last_name='X') + + def resolve_article(self, *args, **kwargs): + return Article(id=1, headline='Article node') + + query = ''' + query ReporterQuery { + reporter { + id, + firstName, + articles { + edges { + node { + headline + } + } + } + lastName, + email + } + myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { + id + ... on ReporterNode { + firstName + } + ... on ArticleNode { + headline + } + } + } + ''' + expected = { + 'reporter': { + 'id': 'UmVwb3J0ZXJOb2RlOjE=', + 'firstName': 'ABA', + 'lastName': 'X', + 'email': None, + 'articles': { + 'edges': [{ + 'node': { + 'headline': 'Hi!' + } + }] + }, + }, + 'myArticle': { + 'id': 'QXJ0aWNsZU5vZGU6MQ==', + 'headline': 'Hi!' + } + } + schema = graphene.Schema(query=Query, session=session) + result = schema.execute(query) + assert not result.errors + assert result.data == expected diff --git a/graphene/contrib/sqlalchemy/tests/test_utils.py b/graphene/contrib/sqlalchemy/tests/test_utils.py new file mode 100644 index 00000000..2874ffaa --- /dev/null +++ b/graphene/contrib/sqlalchemy/tests/test_utils.py @@ -0,0 +1,25 @@ +from graphene import Schema, ObjectType, String + +from ..utils import get_session + + +def test_get_session(): + session = 'My SQLAlchemy session' + schema = Schema(session=session) + + class Query(ObjectType): + x = String() + + def resolve_x(self, args, info): + return get_session(info) + + query = ''' + query ReporterQuery { + x + } + ''' + + schema = Schema(query=Query, session=session) + result = schema.execute(query) + assert not result.errors + assert result.data['x'] == session diff --git a/graphene/contrib/sqlalchemy/types.py b/graphene/contrib/sqlalchemy/types.py index aed62625..1d55d80a 100644 --- a/graphene/contrib/sqlalchemy/types.py +++ b/graphene/contrib/sqlalchemy/types.py @@ -3,13 +3,14 @@ import inspect import six from sqlalchemy.inspection import inspect as sqlalchemyinspect +from sqlalchemy.orm.exc import NoResultFound from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta from ...relay.types import Connection, Node, NodeMeta from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_relationship) from .options import SQLAlchemyOptions -from .utils import is_mapped +from .utils import is_mapped, get_session class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): @@ -116,7 +117,9 @@ class SQLAlchemyNode(six.with_metaclass( @classmethod def get_node(cls, id, info=None): try: - instance = cls._meta.model.filter(id=id).one() + model = cls._meta.model + session = get_session(info) + instance = session.query(model).filter(model.id == id).one() return cls(instance) - except cls._meta.model.DoesNotExist: + except NoResultFound: return None diff --git a/graphene/contrib/sqlalchemy/utils.py b/graphene/contrib/sqlalchemy/utils.py index 8d8c0b27..18093d0e 100644 --- a/graphene/contrib/sqlalchemy/utils.py +++ b/graphene/contrib/sqlalchemy/utils.py @@ -15,6 +15,11 @@ def get_type_for_model(schema, model): return _type +def get_session(info): + schema = info.schema.graphene_schema + return schema.options.get('session') + + def is_mapped(obj): return isinstance(obj, DeclarativeMeta) # try: diff --git a/graphene/core/schema.py b/graphene/core/schema.py index c8695317..78653c14 100644 --- a/graphene/core/schema.py +++ b/graphene/core/schema.py @@ -26,7 +26,7 @@ class Schema(object): _executor = None def __init__(self, query=None, mutation=None, subscription=None, - name='Schema', executor=None, plugins=None, auto_camelcase=True): + name='Schema', executor=None, plugins=None, auto_camelcase=True, **options): self._types_names = {} self._types = {} self.mutation = mutation @@ -38,6 +38,7 @@ class Schema(object): if auto_camelcase: plugins.append(CamelCase()) self.plugins = PluginManager(self, plugins) + self.options = options signals.init_schema.send(self) def __repr__(self):