mirror of
https://github.com/graphql-python/graphene.git
synced 2024-11-22 17:46:57 +03:00
Updated SQLAlchemy integration in graphene
This commit is contained in:
parent
961cb1ad83
commit
017f6ae2a1
|
@ -25,6 +25,7 @@ install:
|
|||
if [ "$TEST_TYPE" = build ]; then
|
||||
pip install --download-cache $HOME/.cache/pip/ pytest pytest-cov coveralls six pytest-django django-filter
|
||||
pip install --download-cache $HOME/.cache/pip/ -e .[django]
|
||||
pip install --download-cache $HOME/.cache/pip/ -e .[sqlalchemy]
|
||||
pip install django==$DJANGO_VERSION
|
||||
python setup.py develop
|
||||
elif [ "$TEST_TYPE" = build_website ]; then
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from graphene.contrib.sqlalchemy.types import (
|
||||
SQLAlchemyObjectType,
|
||||
SQLAlchemyInterface,
|
||||
SQLAlchemyNode
|
||||
)
|
||||
from graphene.contrib.sqlalchemy.fields import (
|
||||
|
@ -8,5 +7,5 @@ from graphene.contrib.sqlalchemy.fields import (
|
|||
SQLAlchemyModelField
|
||||
)
|
||||
|
||||
__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyInterface', 'SQLAlchemyNode',
|
||||
__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyNode',
|
||||
'SQLAlchemyConnectionField', 'SQLAlchemyModelField']
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from sqlalchemy import types
|
||||
from sqlalchemy.orm import interfaces
|
||||
from singledispatch import singledispatch
|
||||
|
||||
from ...core.types.scalars import Boolean, Float, ID, Int, String
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.orm import interfaces
|
||||
|
||||
from ...core.types.scalars import ID, Boolean, Float, Int, String
|
||||
from .fields import ConnectionOrListField, SQLAlchemyModelField
|
||||
|
||||
|
||||
|
|
|
@ -1,31 +1,32 @@
|
|||
from sqlalchemy.orm import Query
|
||||
from ...core.exceptions import SkipField
|
||||
from ...core.fields import Field
|
||||
from ...core.types.base import FieldType
|
||||
from ...core.types.definitions import List
|
||||
from ...relay import ConnectionField
|
||||
from ...relay.utils import is_node
|
||||
from ...utils import LazyMap
|
||||
|
||||
from .utils import get_type_for_model
|
||||
|
||||
|
||||
class SQLAlchemyConnectionField(ConnectionField):
|
||||
|
||||
def wrap_resolved(self, value, instance, args, info):
|
||||
if isinstance(value, Query):
|
||||
return LazyMap(value, self.type)
|
||||
return value
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.session = kwargs.pop('session', None)
|
||||
return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.type._meta.model
|
||||
|
||||
class LazyListField(Field):
|
||||
def get_session(self, args, info):
|
||||
return self.session
|
||||
|
||||
def get_type(self, schema):
|
||||
return List(self.type)
|
||||
def get_query(self, resolved_query, args, info):
|
||||
self.get_session(args, info)
|
||||
return resolved_query
|
||||
|
||||
def resolver(self, instance, args, info):
|
||||
resolved = super(LazyListField, self).resolver(instance, args, info)
|
||||
return LazyMap(resolved, self.type)
|
||||
def from_list(self, connection_type, resolved, args, info):
|
||||
qs = self.get_query(resolved, args, info)
|
||||
return super(SQLAlchemyConnectionField, self).from_list(connection_type, qs, args, info)
|
||||
|
||||
|
||||
class ConnectionOrListField(Field):
|
||||
|
@ -38,7 +39,7 @@ class ConnectionOrListField(Field):
|
|||
if is_node(field_object_type):
|
||||
field = SQLAlchemyConnectionField(field_object_type)
|
||||
else:
|
||||
field = LazyListField(field_object_type)
|
||||
field = Field(List(field_object_type))
|
||||
field.contribute_to_class(self.object_type, self.attname)
|
||||
return schema.T(field)
|
||||
|
||||
|
|
|
@ -1,37 +1,23 @@
|
|||
import inspect
|
||||
|
||||
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
||||
|
||||
from ...core.options import Options
|
||||
from ...core.classtypes.objecttype import ObjectTypeOptions
|
||||
from ...relay.types import Node
|
||||
from ...relay.utils import is_node
|
||||
|
||||
VALID_ATTRS = ('model', 'only_fields', 'exclude_fields')
|
||||
|
||||
|
||||
def is_base(cls):
|
||||
from graphene.contrib.sqlalchemy.types import SQLAlchemyObjectType
|
||||
return SQLAlchemyObjectType in cls.__bases__
|
||||
|
||||
|
||||
class SQLAlchemyOptions(Options):
|
||||
class SQLAlchemyOptions(ObjectTypeOptions):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.model = None
|
||||
super(SQLAlchemyOptions, self).__init__(*args, **kwargs)
|
||||
self.model = None
|
||||
self.valid_attrs += VALID_ATTRS
|
||||
self.only_fields = None
|
||||
self.exclude_fields = []
|
||||
self.filter_fields = None
|
||||
self.filter_order_by = None
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
super(SQLAlchemyOptions, self).contribute_to_class(cls, name)
|
||||
if is_node(cls):
|
||||
self.exclude_fields = list(self.exclude_fields)
|
||||
self.exclude_fields = list(self.exclude_fields) + ['id']
|
||||
self.interfaces.append(Node)
|
||||
if not is_node(cls) and not is_base(cls):
|
||||
return
|
||||
if not self.model:
|
||||
raise Exception(
|
||||
'SQLAlchemy ObjectType %s must have a model in the Meta class attr' % cls)
|
||||
elif not inspect.isclass(self.model) or not isinstance(self.model, DeclarativeMeta):
|
||||
raise Exception('Provided model in %s is not a SQLAlchemy model' % cls)
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from sqlalchemy import Table, Column, Integer, String, Date, ForeignKey
|
||||
from sqlalchemy import Column, Date, ForeignKey, Integer, String, Table
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
association_table = Table('association', Base.metadata,
|
||||
Column('pet_id', Integer, ForeignKey('pets.id')),
|
||||
Column('reporter_id', Integer, ForeignKey('reporters.id')))
|
||||
Column('pet_id', Integer, ForeignKey('pets.id')),
|
||||
Column('reporter_id', Integer, ForeignKey('reporters.id')))
|
||||
|
||||
|
||||
class Pet(Base):
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from sqlalchemy import types, Column
|
||||
from py.test import raises
|
||||
|
||||
import graphene
|
||||
from graphene.contrib.sqlalchemy.converter import convert_sqlalchemy_column, convert_sqlalchemy_relationship
|
||||
from graphene.contrib.sqlalchemy.fields import ConnectionOrListField, SQLAlchemyModelField
|
||||
from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column,
|
||||
convert_sqlalchemy_relationship)
|
||||
from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField,
|
||||
SQLAlchemyModelField)
|
||||
from sqlalchemy import Column, types
|
||||
|
||||
from .models import Article, Reporter, Pet
|
||||
from .models import Article, Pet, Reporter
|
||||
|
||||
|
||||
def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs):
|
||||
|
@ -72,7 +74,7 @@ def test_should_integer_convert_id():
|
|||
|
||||
|
||||
def test_should_boolean_convert_boolean():
|
||||
field = assert_column_conversion(types.Boolean(), graphene.Boolean)
|
||||
assert_column_conversion(types.Boolean(), graphene.Boolean)
|
||||
|
||||
|
||||
def test_should_float_convert_float():
|
||||
|
|
|
@ -1,30 +1,44 @@
|
|||
from py.test import raises
|
||||
import pytest
|
||||
|
||||
import graphene
|
||||
from graphene import relay
|
||||
from graphene.contrib.sqlalchemy import SQLAlchemyNode, SQLAlchemyObjectType
|
||||
from .models import Article, Reporter
|
||||
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
|
||||
from .models import Base, Reporter
|
||||
|
||||
db = create_engine('sqlite:///test_sqlalchemy.sqlite3')
|
||||
|
||||
|
||||
def test_should_query_only_fields():
|
||||
with raises(Exception):
|
||||
class ReporterType(SQLAlchemyObjectType):
|
||||
@pytest.yield_fixture(scope='function')
|
||||
def session():
|
||||
connection = db.engine.connect()
|
||||
transaction = connection.begin()
|
||||
Base.metadata.create_all(connection)
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
only_fields = ('articles', )
|
||||
# options = dict(bind=connection, binds={})
|
||||
session_factory = sessionmaker(bind=connection)
|
||||
session = scoped_session(session_factory)
|
||||
|
||||
schema = graphene.Schema(query=ReporterType)
|
||||
query = '''
|
||||
query ReporterQuery {
|
||||
articles
|
||||
}
|
||||
'''
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
yield session
|
||||
|
||||
# Finalize test here
|
||||
transaction.rollback()
|
||||
connection.close()
|
||||
session.remove()
|
||||
|
||||
|
||||
def test_should_query_well():
|
||||
def setup_fixtures(session):
|
||||
reporter = Reporter(first_name='ABA', last_name='X')
|
||||
session.add(reporter)
|
||||
reporter2 = Reporter(first_name='ABO', last_name='Y')
|
||||
session.add(reporter2)
|
||||
session.commit()
|
||||
|
||||
|
||||
def test_should_query_well(session):
|
||||
setup_fixtures(session)
|
||||
|
||||
class ReporterType(SQLAlchemyObjectType):
|
||||
|
||||
class Meta:
|
||||
|
@ -32,9 +46,13 @@ def test_should_query_well():
|
|||
|
||||
class Query(graphene.ObjectType):
|
||||
reporter = graphene.Field(ReporterType)
|
||||
reporters = ReporterType.List()
|
||||
|
||||
def resolve_reporter(self, *args, **kwargs):
|
||||
return ReporterType(Reporter(first_name='ABA', last_name='X'))
|
||||
return session.query(Reporter).first()
|
||||
|
||||
def resolve_reporters(self, *args, **kwargs):
|
||||
return session.query(Reporter)
|
||||
|
||||
query = '''
|
||||
query ReporterQuery {
|
||||
|
@ -43,6 +61,9 @@ def test_should_query_well():
|
|||
lastName,
|
||||
email
|
||||
}
|
||||
reporters {
|
||||
firstName
|
||||
}
|
||||
}
|
||||
'''
|
||||
expected = {
|
||||
|
@ -50,90 +71,12 @@ def test_should_query_well():
|
|||
'firstName': 'ABA',
|
||||
'lastName': 'X',
|
||||
'email': None
|
||||
}
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query)
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
||||
def test_should_node():
|
||||
class ReporterNode(SQLAlchemyNode):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
exclude_fields = ('id', )
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, id, info):
|
||||
return ReporterNode(Reporter(id=2, first_name='Cookie Monster'))
|
||||
|
||||
def resolve_articles(self, *args, **kwargs):
|
||||
return [ArticleNode(Article(headline='Hi!'))]
|
||||
|
||||
class ArticleNode(SQLAlchemyNode):
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
exclude_fields = ('id', )
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, id, info):
|
||||
return ArticleNode(Article(id=1, headline='Article node'))
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
node = relay.NodeField()
|
||||
reporter = graphene.Field(ReporterNode)
|
||||
article = graphene.Field(ArticleNode)
|
||||
|
||||
def resolve_reporter(self, *args, **kwargs):
|
||||
return ReporterNode(Reporter(id=1, first_name='ABA', last_name='X'))
|
||||
|
||||
query = '''
|
||||
query ReporterQuery {
|
||||
reporter {
|
||||
id,
|
||||
firstName,
|
||||
articles {
|
||||
edges {
|
||||
node {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
lastName,
|
||||
email
|
||||
}
|
||||
myArticle: node(id:"QXJ0aWNsZU5vZGU6MQ==") {
|
||||
id
|
||||
... on ReporterNode {
|
||||
firstName
|
||||
}
|
||||
... on ArticleNode {
|
||||
headline
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
expected = {
|
||||
'reporter': {
|
||||
'id': 'UmVwb3J0ZXJOb2RlOjE=',
|
||||
'firstName': 'ABA',
|
||||
'lastName': 'X',
|
||||
'email': None,
|
||||
'articles': {
|
||||
'edges': [{
|
||||
'node': {
|
||||
'headline': 'Hi!'
|
||||
}
|
||||
}]
|
||||
},
|
||||
},
|
||||
'myArticle': {
|
||||
'id': 'QXJ0aWNsZU5vZGU6MQ==',
|
||||
'headline': 'Article node'
|
||||
}
|
||||
'reporters': [{
|
||||
'firstName': 'ABA',
|
||||
}, {
|
||||
'firstName': 'ABO',
|
||||
}]
|
||||
}
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query)
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from graphql.core.type import GraphQLInterfaceType, GraphQLObjectType
|
||||
from graphql.core.type import GraphQLObjectType
|
||||
from pytest import raises
|
||||
|
||||
from graphene import Schema
|
||||
from graphene.contrib.sqlalchemy.types import SQLAlchemyInterface, SQLAlchemyNode
|
||||
from graphene.contrib.sqlalchemy.types import (SQLAlchemyNode,
|
||||
SQLAlchemyObjectType)
|
||||
from graphene.core.fields import Field
|
||||
from graphene.core.types.scalars import Int
|
||||
from graphene.relay.fields import GlobalIDField
|
||||
|
@ -13,7 +14,7 @@ from .models import Article, Reporter
|
|||
schema = Schema()
|
||||
|
||||
|
||||
class Character(SQLAlchemyInterface):
|
||||
class Character(SQLAlchemyObjectType):
|
||||
'''Character description'''
|
||||
class Meta:
|
||||
model = Reporter
|
||||
|
@ -31,23 +32,23 @@ class Human(SQLAlchemyNode):
|
|||
|
||||
|
||||
def test_sqlalchemy_interface():
|
||||
assert SQLAlchemyNode._meta.is_interface is True
|
||||
assert SQLAlchemyNode._meta.interface is True
|
||||
|
||||
|
||||
def test_sqlalchemy_get_node(get):
|
||||
human = Human.get_node(1, None)
|
||||
get.assert_called_with(id=1)
|
||||
assert human.id == 1
|
||||
# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1))
|
||||
# def test_sqlalchemy_get_node(get):
|
||||
# human = Human.get_node(1, None)
|
||||
# get.assert_called_with(id=1)
|
||||
# assert human.id == 1
|
||||
|
||||
|
||||
def test_pseudo_interface_registered():
|
||||
def test_objecttype_registered():
|
||||
object_type = schema.T(Character)
|
||||
assert Character._meta.is_interface is True
|
||||
assert isinstance(object_type, GraphQLInterfaceType)
|
||||
assert isinstance(object_type, GraphQLObjectType)
|
||||
assert Character._meta.model == Reporter
|
||||
assert_equal_lists(
|
||||
object_type.get_fields().keys(),
|
||||
['articles', 'firstName', 'lastName', 'email', 'pets', 'id']
|
||||
['articles', 'firstName', 'lastName', 'email', 'id']
|
||||
)
|
||||
|
||||
|
||||
|
@ -67,11 +68,6 @@ def test_node_replacedfield():
|
|||
assert schema.T(idfield).type == schema.T(Int())
|
||||
|
||||
|
||||
def test_interface_resolve_type():
|
||||
resolve_type = Character.resolve_type(schema, Human())
|
||||
assert isinstance(resolve_type, GraphQLObjectType)
|
||||
|
||||
|
||||
def test_interface_objecttype_init_none():
|
||||
h = Human()
|
||||
assert h._root is None
|
||||
|
@ -92,7 +88,7 @@ def test_interface_objecttype_init_unexpected():
|
|||
def test_object_type():
|
||||
object_type = schema.T(Human)
|
||||
Human._meta.fields_map
|
||||
assert Human._meta.is_interface is False
|
||||
assert Human._meta.interface is False
|
||||
assert isinstance(object_type, GraphQLObjectType)
|
||||
assert_equal_lists(
|
||||
object_type.get_fields().keys(),
|
||||
|
@ -102,5 +98,5 @@ def test_object_type():
|
|||
|
||||
|
||||
def test_node_notinterface():
|
||||
assert Human._meta.is_interface is False
|
||||
assert Human._meta.interface is False
|
||||
assert SQLAlchemyNode in Human._meta.interfaces
|
||||
|
|
|
@ -1,26 +1,25 @@
|
|||
import six
|
||||
from sqlalchemy.inspection import inspect
|
||||
import inspect
|
||||
|
||||
from ...core.types import BaseObjectType, ObjectTypeMeta
|
||||
from ...relay.fields import GlobalIDField
|
||||
from ...relay.types import BaseNode
|
||||
from .converter import convert_sqlalchemy_column, convert_sqlalchemy_relationship
|
||||
import six
|
||||
|
||||
from sqlalchemy.inspection import inspect as sqlalchemyinspect
|
||||
|
||||
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
|
||||
from ...relay.types import Connection, Node, NodeMeta
|
||||
from .converter import (convert_sqlalchemy_column,
|
||||
convert_sqlalchemy_relationship)
|
||||
from .options import SQLAlchemyOptions
|
||||
from .utils import is_mapped
|
||||
|
||||
|
||||
class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
|
||||
options_cls = SQLAlchemyOptions
|
||||
options_class = SQLAlchemyOptions
|
||||
|
||||
def is_interface(cls, parents):
|
||||
return SQLAlchemyInterface in parents
|
||||
|
||||
def add_extra_fields(cls):
|
||||
if not cls._meta.model:
|
||||
return
|
||||
def construct_fields(cls):
|
||||
only_fields = cls._meta.only_fields
|
||||
exclude_fields = cls._meta.exclude_fields
|
||||
already_created_fields = {f.attname for f in cls._meta.local_fields}
|
||||
inspected_model = inspect(cls._meta.model)
|
||||
inspected_model = sqlalchemyinspect(cls._meta.model)
|
||||
|
||||
# Get all the columns for the relationships on the model
|
||||
for relationship in inspected_model.relationships:
|
||||
|
@ -45,8 +44,24 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
|
|||
converted_column = convert_sqlalchemy_column(column)
|
||||
cls.add_to_class(column.name, converted_column)
|
||||
|
||||
def construct(cls, *args, **kwargs):
|
||||
cls = super(SQLAlchemyObjectTypeMeta, cls).construct(*args, **kwargs)
|
||||
if not cls._meta.abstract:
|
||||
if not cls._meta.model:
|
||||
raise Exception(
|
||||
'SQLAlchemy ObjectType %s must have a model in the Meta class attr' %
|
||||
cls)
|
||||
elif not inspect.isclass(cls._meta.model) or not is_mapped(cls._meta.model):
|
||||
raise Exception('Provided model in %s is not a SQLAlchemy model' % cls)
|
||||
|
||||
class InstanceObjectType(BaseObjectType):
|
||||
cls.construct_fields()
|
||||
return cls
|
||||
|
||||
|
||||
class InstanceObjectType(ObjectType):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
def __init__(self, _root=None):
|
||||
if _root:
|
||||
|
@ -71,16 +86,32 @@ class InstanceObjectType(BaseObjectType):
|
|||
return getattr(self._root, attr)
|
||||
|
||||
|
||||
class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, InstanceObjectType)):
|
||||
class SQLAlchemyObjectType(six.with_metaclass(
|
||||
SQLAlchemyObjectTypeMeta, InstanceObjectType)):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class SQLAlchemyConnection(Connection):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAlchemyInterface(six.with_metaclass(SQLAlchemyObjectTypeMeta, InstanceObjectType)):
|
||||
class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, NodeMeta):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAlchemyNode(BaseNode, SQLAlchemyInterface):
|
||||
id = GlobalIDField()
|
||||
class NodeInstance(Node, InstanceObjectType):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class SQLAlchemyNode(six.with_metaclass(
|
||||
SQLAlchemyNodeMeta, NodeInstance)):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
@classmethod
|
||||
def get_node(cls, id, info=None):
|
||||
|
|
|
@ -1,3 +1,10 @@
|
|||
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
||||
|
||||
|
||||
# from sqlalchemy.orm.base import object_mapper
|
||||
# from sqlalchemy.orm.exc import UnmappedInstanceError
|
||||
|
||||
|
||||
def get_type_for_model(schema, model):
|
||||
schema = schema
|
||||
types = schema.types.values()
|
||||
|
@ -6,3 +13,12 @@ def get_type_for_model(schema, model):
|
|||
_type._meta, 'model', None)
|
||||
if model == type_model:
|
||||
return _type
|
||||
|
||||
|
||||
def is_mapped(obj):
|
||||
return isinstance(obj, DeclarativeMeta)
|
||||
# try:
|
||||
# object_mapper(obj)
|
||||
# except UnmappedInstanceError:
|
||||
# return False
|
||||
# return True
|
||||
|
|
Loading…
Reference in New Issue
Block a user