diff --git a/.travis.yml b/.travis.yml index b6996d24..ae23cd02 100644 --- a/.travis.yml +++ b/.travis.yml @@ -25,6 +25,7 @@ install: if [ "$TEST_TYPE" = build ]; then pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter pip install --download-cache $HOME/.cache/pip/ -e .[django] + pip install --download-cache $HOME/.cache/pip/ -e .[sqlalchemy] pip install django==$DJANGO_VERSION python setup.py develop elif [ "$TEST_TYPE" = build_website ]; then diff --git a/graphene/contrib/sqlalchemy/__init__.py b/graphene/contrib/sqlalchemy/__init__.py index ffeac8e5..88509ba3 100644 --- a/graphene/contrib/sqlalchemy/__init__.py +++ b/graphene/contrib/sqlalchemy/__init__.py @@ -1,6 +1,5 @@ from graphene.contrib.sqlalchemy.types import ( SQLAlchemyObjectType, - SQLAlchemyInterface, SQLAlchemyNode ) from graphene.contrib.sqlalchemy.fields import ( @@ -8,5 +7,5 @@ from graphene.contrib.sqlalchemy.fields import ( SQLAlchemyModelField ) -__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyInterface', 'SQLAlchemyNode', +__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyNode', 'SQLAlchemyConnectionField', 'SQLAlchemyModelField'] diff --git a/graphene/contrib/sqlalchemy/converter.py b/graphene/contrib/sqlalchemy/converter.py index e29cff1f..4e609dca 100644 --- a/graphene/contrib/sqlalchemy/converter.py +++ b/graphene/contrib/sqlalchemy/converter.py @@ -1,8 +1,9 @@ -from sqlalchemy import types -from sqlalchemy.orm import interfaces from singledispatch import singledispatch -from ...core.types.scalars import Boolean, Float, ID, Int, String +from sqlalchemy import types +from sqlalchemy.orm import interfaces + +from ...core.types.scalars import ID, Boolean, Float, Int, String from .fields import ConnectionOrListField, SQLAlchemyModelField diff --git a/graphene/contrib/sqlalchemy/fields.py b/graphene/contrib/sqlalchemy/fields.py index 2083c191..c38ffcac 100644 --- a/graphene/contrib/sqlalchemy/fields.py +++ b/graphene/contrib/sqlalchemy/fields.py @@ -1,31 +1,32 @@ -from sqlalchemy.orm import Query from ...core.exceptions import SkipField from ...core.fields import Field from ...core.types.base import FieldType from ...core.types.definitions import List from ...relay import ConnectionField from ...relay.utils import is_node -from ...utils import LazyMap - from .utils import get_type_for_model class SQLAlchemyConnectionField(ConnectionField): - def wrap_resolved(self, value, instance, args, info): - if isinstance(value, Query): - return LazyMap(value, self.type) - return value + def __init__(self, *args, **kwargs): + self.session = kwargs.pop('session', None) + return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs) + @property + def model(self): + return self.type._meta.model -class LazyListField(Field): + def get_session(self, args, info): + return self.session - def get_type(self, schema): - return List(self.type) + def get_query(self, resolved_query, args, info): + self.get_session(args, info) + return resolved_query - def resolver(self, instance, args, info): - resolved = super(LazyListField, self).resolver(instance, args, info) - return LazyMap(resolved, self.type) + def from_list(self, connection_type, resolved, args, info): + qs = self.get_query(resolved, args, info) + return super(SQLAlchemyConnectionField, self).from_list(connection_type, qs, args, info) class ConnectionOrListField(Field): @@ -38,7 +39,7 @@ class ConnectionOrListField(Field): if is_node(field_object_type): field = SQLAlchemyConnectionField(field_object_type) else: - field = LazyListField(field_object_type) + field = Field(List(field_object_type)) field.contribute_to_class(self.object_type, self.attname) return schema.T(field) diff --git a/graphene/contrib/sqlalchemy/options.py b/graphene/contrib/sqlalchemy/options.py index e1d57827..1d4b2a4f 100644 --- a/graphene/contrib/sqlalchemy/options.py +++ b/graphene/contrib/sqlalchemy/options.py @@ -1,37 +1,23 @@ -import inspect - -from sqlalchemy.ext.declarative.api import DeclarativeMeta - -from ...core.options import Options +from ...core.classtypes.objecttype import ObjectTypeOptions from ...relay.types import Node from ...relay.utils import is_node VALID_ATTRS = ('model', 'only_fields', 'exclude_fields') -def is_base(cls): - from graphene.contrib.sqlalchemy.types import SQLAlchemyObjectType - return SQLAlchemyObjectType in cls.__bases__ - - -class SQLAlchemyOptions(Options): +class SQLAlchemyOptions(ObjectTypeOptions): def __init__(self, *args, **kwargs): - self.model = None super(SQLAlchemyOptions, self).__init__(*args, **kwargs) + self.model = None self.valid_attrs += VALID_ATTRS self.only_fields = None self.exclude_fields = [] + self.filter_fields = None + self.filter_order_by = None def contribute_to_class(self, cls, name): super(SQLAlchemyOptions, self).contribute_to_class(cls, name) if is_node(cls): - self.exclude_fields = list(self.exclude_fields) + self.exclude_fields = list(self.exclude_fields) + ['id'] self.interfaces.append(Node) - if not is_node(cls) and not is_base(cls): - return - if not self.model: - raise Exception( - 'SQLAlchemy ObjectType %s must have a model in the Meta class attr' % cls) - elif not inspect.isclass(self.model) or not isinstance(self.model, DeclarativeMeta): - raise Exception('Provided model in %s is not a SQLAlchemy model' % cls) diff --git a/graphene/contrib/sqlalchemy/tests/models.py b/graphene/contrib/sqlalchemy/tests/models.py index 8c97159e..ee021054 100644 --- a/graphene/contrib/sqlalchemy/tests/models.py +++ b/graphene/contrib/sqlalchemy/tests/models.py @@ -1,15 +1,14 @@ from __future__ import absolute_import -from sqlalchemy import Table, Column, Integer, String, Date, ForeignKey +from sqlalchemy import Column, Date, ForeignKey, Integer, String, Table from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship - Base = declarative_base() association_table = Table('association', Base.metadata, - Column('pet_id', Integer, ForeignKey('pets.id')), - Column('reporter_id', Integer, ForeignKey('reporters.id'))) + Column('pet_id', Integer, ForeignKey('pets.id')), + Column('reporter_id', Integer, ForeignKey('reporters.id'))) class Pet(Base): diff --git a/graphene/contrib/sqlalchemy/tests/test_converter.py b/graphene/contrib/sqlalchemy/tests/test_converter.py index 6404c466..e4cdaa6f 100644 --- a/graphene/contrib/sqlalchemy/tests/test_converter.py +++ b/graphene/contrib/sqlalchemy/tests/test_converter.py @@ -1,11 +1,13 @@ -from sqlalchemy import types, Column from py.test import raises import graphene -from graphene.contrib.sqlalchemy.converter import convert_sqlalchemy_column, convert_sqlalchemy_relationship -from graphene.contrib.sqlalchemy.fields import ConnectionOrListField, SQLAlchemyModelField +from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column, + convert_sqlalchemy_relationship) +from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField, + SQLAlchemyModelField) +from sqlalchemy import Column, types -from .models import Article, Reporter, Pet +from .models import Article, Pet, Reporter def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): @@ -72,7 +74,7 @@ def test_should_integer_convert_id(): def test_should_boolean_convert_boolean(): - field = assert_column_conversion(types.Boolean(), graphene.Boolean) + assert_column_conversion(types.Boolean(), graphene.Boolean) def test_should_float_convert_float(): diff --git a/graphene/contrib/sqlalchemy/tests/test_query.py b/graphene/contrib/sqlalchemy/tests/test_query.py index e981b146..0384aa8b 100644 --- a/graphene/contrib/sqlalchemy/tests/test_query.py +++ b/graphene/contrib/sqlalchemy/tests/test_query.py @@ -1,30 +1,44 @@ -from py.test import raises +import pytest import graphene -from graphene import relay -from graphene.contrib.sqlalchemy import SQLAlchemyNode, SQLAlchemyObjectType -from .models import Article, Reporter +from graphene.contrib.sqlalchemy import SQLAlchemyObjectType +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session, sessionmaker + +from .models import Base, Reporter + +db = create_engine('sqlite:///test_sqlalchemy.sqlite3') -def test_should_query_only_fields(): - with raises(Exception): - class ReporterType(SQLAlchemyObjectType): +@pytest.yield_fixture(scope='function') +def session(): + connection = db.engine.connect() + transaction = connection.begin() + Base.metadata.create_all(connection) - class Meta: - model = Reporter - only_fields = ('articles', ) + # options = dict(bind=connection, binds={}) + session_factory = sessionmaker(bind=connection) + session = scoped_session(session_factory) - schema = graphene.Schema(query=ReporterType) - query = ''' - query ReporterQuery { - articles - } - ''' - result = schema.execute(query) - assert not result.errors + yield session + + # Finalize test here + transaction.rollback() + connection.close() + session.remove() -def test_should_query_well(): +def setup_fixtures(session): + reporter = Reporter(first_name='ABA', last_name='X') + session.add(reporter) + reporter2 = Reporter(first_name='ABO', last_name='Y') + session.add(reporter2) + session.commit() + + +def test_should_query_well(session): + setup_fixtures(session) + class ReporterType(SQLAlchemyObjectType): class Meta: @@ -32,9 +46,13 @@ def test_should_query_well(): class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) + reporters = ReporterType.List() def resolve_reporter(self, *args, **kwargs): - return ReporterType(Reporter(first_name='ABA', last_name='X')) + return session.query(Reporter).first() + + def resolve_reporters(self, *args, **kwargs): + return session.query(Reporter) query = ''' query ReporterQuery { @@ -43,6 +61,9 @@ def test_should_query_well(): lastName, email } + reporters { + firstName + } } ''' expected = { @@ -50,90 +71,12 @@ def test_should_query_well(): 'firstName': 'ABA', 'lastName': 'X', 'email': None - } - } - schema = graphene.Schema(query=Query) - result = schema.execute(query) - assert not result.errors - assert result.data == expected - - -def test_should_node(): - class ReporterNode(SQLAlchemyNode): - - class Meta: - model = Reporter - exclude_fields = ('id', ) - - @classmethod - def get_node(cls, id, info): - return ReporterNode(Reporter(id=2, first_name='Cookie Monster')) - - def resolve_articles(self, *args, **kwargs): - return [ArticleNode(Article(headline='Hi!'))] - - class ArticleNode(SQLAlchemyNode): - - class Meta: - model = Article - exclude_fields = ('id', ) - - @classmethod - def get_node(cls, id, info): - return ArticleNode(Article(id=1, headline='Article node')) - - class Query(graphene.ObjectType): - node = relay.NodeField() - reporter = graphene.Field(ReporterNode) - article = graphene.Field(ArticleNode) - - def resolve_reporter(self, *args, **kwargs): - return ReporterNode(Reporter(id=1, first_name='ABA', last_name='X')) - - query = ''' - query ReporterQuery { - reporter { - id, - firstName, - articles { - edges { - node { - headline - } - } - } - lastName, - email - } - myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") { - id - ... on ReporterNode { - firstName - } - ... on ArticleNode { - headline - } - } - } - ''' - expected = { - 'reporter': { - 'id': 'UmVwb3J0ZXJOb2RlOjE=', - 'firstName': 'ABA', - 'lastName': 'X', - 'email': None, - 'articles': { - 'edges': [{ - 'node': { - 'headline': 'Hi!' - } - }] - }, }, - 'myArticle': { - 'id': 'QXJ0aWNsZU5vZGU6MQ==', - 'headline': 'Article node' - } + 'reporters': [{ + 'firstName': 'ABA', + }, { + 'firstName': 'ABO', + }] } schema = graphene.Schema(query=Query) result = schema.execute(query) diff --git a/graphene/contrib/sqlalchemy/tests/test_types.py b/graphene/contrib/sqlalchemy/tests/test_types.py index f45fd447..feffbc74 100644 --- a/graphene/contrib/sqlalchemy/tests/test_types.py +++ b/graphene/contrib/sqlalchemy/tests/test_types.py @@ -1,8 +1,9 @@ -from graphql.core.type import GraphQLInterfaceType, GraphQLObjectType +from graphql.core.type import GraphQLObjectType from pytest import raises from graphene import Schema -from graphene.contrib.sqlalchemy.types import SQLAlchemyInterface, SQLAlchemyNode +from graphene.contrib.sqlalchemy.types import (SQLAlchemyNode, + SQLAlchemyObjectType) from graphene.core.fields import Field from graphene.core.types.scalars import Int from graphene.relay.fields import GlobalIDField @@ -13,7 +14,7 @@ from .models import Article, Reporter schema = Schema() -class Character(SQLAlchemyInterface): +class Character(SQLAlchemyObjectType): '''Character description''' class Meta: model = Reporter @@ -31,23 +32,23 @@ class Human(SQLAlchemyNode): def test_sqlalchemy_interface(): - assert SQLAlchemyNode._meta.is_interface is True + assert SQLAlchemyNode._meta.interface is True -def test_sqlalchemy_get_node(get): - human = Human.get_node(1, None) - get.assert_called_with(id=1) - assert human.id == 1 +# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) +# def test_sqlalchemy_get_node(get): +# human = Human.get_node(1, None) +# get.assert_called_with(id=1) +# assert human.id == 1 -def test_pseudo_interface_registered(): +def test_objecttype_registered(): object_type = schema.T(Character) - assert Character._meta.is_interface is True - assert isinstance(object_type, GraphQLInterfaceType) + assert isinstance(object_type, GraphQLObjectType) assert Character._meta.model == Reporter assert_equal_lists( object_type.get_fields().keys(), - ['articles', 'firstName', 'lastName', 'email', 'pets', 'id'] + ['articles', 'firstName', 'lastName', 'email', 'id'] ) @@ -67,11 +68,6 @@ def test_node_replacedfield(): assert schema.T(idfield).type == schema.T(Int()) -def test_interface_resolve_type(): - resolve_type = Character.resolve_type(schema, Human()) - assert isinstance(resolve_type, GraphQLObjectType) - - def test_interface_objecttype_init_none(): h = Human() assert h._root is None @@ -92,7 +88,7 @@ def test_interface_objecttype_init_unexpected(): def test_object_type(): object_type = schema.T(Human) Human._meta.fields_map - assert Human._meta.is_interface is False + assert Human._meta.interface is False assert isinstance(object_type, GraphQLObjectType) assert_equal_lists( object_type.get_fields().keys(), @@ -102,5 +98,5 @@ def test_object_type(): def test_node_notinterface(): - assert Human._meta.is_interface is False + assert Human._meta.interface is False assert SQLAlchemyNode in Human._meta.interfaces diff --git a/graphene/contrib/sqlalchemy/types.py b/graphene/contrib/sqlalchemy/types.py index a96f7848..aed62625 100644 --- a/graphene/contrib/sqlalchemy/types.py +++ b/graphene/contrib/sqlalchemy/types.py @@ -1,26 +1,25 @@ -import six -from sqlalchemy.inspection import inspect +import inspect -from ...core.types import BaseObjectType, ObjectTypeMeta -from ...relay.fields import GlobalIDField -from ...relay.types import BaseNode -from .converter import convert_sqlalchemy_column, convert_sqlalchemy_relationship +import six + +from sqlalchemy.inspection import inspect as sqlalchemyinspect + +from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta +from ...relay.types import Connection, Node, NodeMeta +from .converter import (convert_sqlalchemy_column, + convert_sqlalchemy_relationship) from .options import SQLAlchemyOptions +from .utils import is_mapped class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): - options_cls = SQLAlchemyOptions + options_class = SQLAlchemyOptions - def is_interface(cls, parents): - return SQLAlchemyInterface in parents - - def add_extra_fields(cls): - if not cls._meta.model: - return + def construct_fields(cls): only_fields = cls._meta.only_fields exclude_fields = cls._meta.exclude_fields already_created_fields = {f.attname for f in cls._meta.local_fields} - inspected_model = inspect(cls._meta.model) + inspected_model = sqlalchemyinspect(cls._meta.model) # Get all the columns for the relationships on the model for relationship in inspected_model.relationships: @@ -45,8 +44,24 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): converted_column = convert_sqlalchemy_column(column) cls.add_to_class(column.name, converted_column) + def construct(cls, *args, **kwargs): + cls = super(SQLAlchemyObjectTypeMeta, cls).construct(*args, **kwargs) + if not cls._meta.abstract: + if not cls._meta.model: + raise Exception( + 'SQLAlchemy ObjectType %s must have a model in the Meta class attr' % + cls) + elif not inspect.isclass(cls._meta.model) or not is_mapped(cls._meta.model): + raise Exception('Provided model in %s is not a SQLAlchemy model' % cls) -class InstanceObjectType(BaseObjectType): + cls.construct_fields() + return cls + + +class InstanceObjectType(ObjectType): + + class Meta: + abstract = True def __init__(self, _root=None): if _root: @@ -71,16 +86,32 @@ class InstanceObjectType(BaseObjectType): return getattr(self._root, attr) -class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, InstanceObjectType)): +class SQLAlchemyObjectType(six.with_metaclass( + SQLAlchemyObjectTypeMeta, InstanceObjectType)): + + class Meta: + abstract = True + + +class SQLAlchemyConnection(Connection): pass -class SQLAlchemyInterface(six.with_metaclass(SQLAlchemyObjectTypeMeta, InstanceObjectType)): +class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, NodeMeta): pass -class SQLAlchemyNode(BaseNode, SQLAlchemyInterface): - id = GlobalIDField() +class NodeInstance(Node, InstanceObjectType): + + class Meta: + abstract = True + + +class SQLAlchemyNode(six.with_metaclass( + SQLAlchemyNodeMeta, NodeInstance)): + + class Meta: + abstract = True @classmethod def get_node(cls, id, info=None): diff --git a/graphene/contrib/sqlalchemy/utils.py b/graphene/contrib/sqlalchemy/utils.py index 48380ba1..8d8c0b27 100644 --- a/graphene/contrib/sqlalchemy/utils.py +++ b/graphene/contrib/sqlalchemy/utils.py @@ -1,3 +1,10 @@ +from sqlalchemy.ext.declarative.api import DeclarativeMeta + + +# from sqlalchemy.orm.base import object_mapper +# from sqlalchemy.orm.exc import UnmappedInstanceError + + def get_type_for_model(schema, model): schema = schema types = schema.types.values() @@ -6,3 +13,12 @@ def get_type_for_model(schema, model): _type._meta, 'model', None) if model == type_model: return _type + + +def is_mapped(obj): + return isinstance(obj, DeclarativeMeta) + # try: + # object_mapper(obj) + # except UnmappedInstanceError: + # return False + # return True diff --git a/setup.py b/setup.py index 937d55fd..9e67851f 100644 --- a/setup.py +++ b/setup.py @@ -62,6 +62,7 @@ setup( 'django-filter>=0.10.0', 'pytest>=2.7.2', 'pytest-django', + 'sqlalchemy', 'mock', ], extras_require={ @@ -71,7 +72,7 @@ setup( 'graphql-django-view>=1.1.0', ], 'sqlalchemy': [ - 'SQLAlchemy' + 'sqlalchemy' ] },