First working version of graphene-sqlalchemy

This commit is contained in:
Syrus Akbary 2016-07-22 20:18:23 -07:00
parent 79d7636ab6
commit af4c63512c
15 changed files with 318 additions and 333 deletions

View File

@ -4,7 +4,6 @@ from django.utils.encoding import force_text
from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.types.datetime import DateTime from graphene.types.datetime import DateTime
from graphene.types.json import JSONString
from graphene.utils.str_converters import to_const from graphene.utils.str_converters import to_const
from graphene.relay import Node from graphene.relay import Node

View File

@ -1,8 +1,8 @@
from graphene.contrib.sqlalchemy.types import ( from .types import (
SQLAlchemyObjectType, SQLAlchemyObjectType,
SQLAlchemyNode SQLAlchemyNode
) )
from graphene.contrib.sqlalchemy.fields import ( from .fields import (
SQLAlchemyConnectionField SQLAlchemyConnectionField
) )

View File

@ -3,9 +3,10 @@ from sqlalchemy import types
from sqlalchemy.orm import interfaces from sqlalchemy.orm import interfaces
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from graphene import Enum, ID, Boolean, Float, Int, String, List from graphene import Enum, ID, Boolean, Float, Int, String, List, Field
from graphene.relay import Node
from graphene.types.json import JSONString from graphene.types.json import JSONString
from .fields import ConnectionOrListField, SQLAlchemyModelField from .fields import SQLAlchemyConnectionField
try: try:
from sqlalchemy_utils.types.choice import ChoiceType from sqlalchemy_utils.types.choice import ChoiceType
@ -14,23 +15,27 @@ except ImportError:
pass pass
def convert_sqlalchemy_relationship(relationship): def convert_sqlalchemy_relationship(relationship, registry):
direction = relationship.direction direction = relationship.direction
model = relationship.mapper.entity model = relationship.mapper.entity
model_field = SQLAlchemyModelField(model, description=relationship.doc) _type = registry.get_type_for_model(model)
if not _type:
return None
if direction == interfaces.MANYTOONE: if direction == interfaces.MANYTOONE:
return model_field return Field(_type)
elif (direction == interfaces.ONETOMANY or elif (direction == interfaces.ONETOMANY or
direction == interfaces.MANYTOMANY): direction == interfaces.MANYTOMANY):
return ConnectionOrListField(model_field) if issubclass(_type, Node):
return SQLAlchemyConnectionField(_type)
return List(_type)
def convert_sqlalchemy_column(column): def convert_sqlalchemy_column(column, registry=None):
return convert_sqlalchemy_type(getattr(column, 'type', None), column) return convert_sqlalchemy_type(getattr(column, 'type', None), column, registry)
@singledispatch @singledispatch
def convert_sqlalchemy_type(type, column): def convert_sqlalchemy_type(type, column, registry=None):
raise Exception( raise Exception(
"Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__)) "Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__))
@ -45,14 +50,14 @@ def convert_sqlalchemy_type(type, column):
@convert_sqlalchemy_type.register(types.Enum) @convert_sqlalchemy_type.register(types.Enum)
@convert_sqlalchemy_type.register(postgresql.ENUM) @convert_sqlalchemy_type.register(postgresql.ENUM)
@convert_sqlalchemy_type.register(postgresql.UUID) @convert_sqlalchemy_type.register(postgresql.UUID)
def convert_column_to_string(type, column): def convert_column_to_string(type, column, registry=None):
return String(description=column.doc) return String(description=column.doc)
@convert_sqlalchemy_type.register(types.SmallInteger) @convert_sqlalchemy_type.register(types.SmallInteger)
@convert_sqlalchemy_type.register(types.BigInteger) @convert_sqlalchemy_type.register(types.BigInteger)
@convert_sqlalchemy_type.register(types.Integer) @convert_sqlalchemy_type.register(types.Integer)
def convert_column_to_int_or_id(type, column): def convert_column_to_int_or_id(type, column, registry=None):
if column.primary_key: if column.primary_key:
return ID(description=column.doc) return ID(description=column.doc)
else: else:
@ -60,24 +65,24 @@ def convert_column_to_int_or_id(type, column):
@convert_sqlalchemy_type.register(types.Boolean) @convert_sqlalchemy_type.register(types.Boolean)
def convert_column_to_boolean(type, column): def convert_column_to_boolean(type, column, registry=None):
return Boolean(description=column.doc) return Boolean(description=column.doc)
@convert_sqlalchemy_type.register(types.Float) @convert_sqlalchemy_type.register(types.Float)
@convert_sqlalchemy_type.register(types.Numeric) @convert_sqlalchemy_type.register(types.Numeric)
def convert_column_to_float(type, column): def convert_column_to_float(type, column, registry=None):
return Float(description=column.doc) return Float(description=column.doc)
@convert_sqlalchemy_type.register(ChoiceType) @convert_sqlalchemy_type.register(ChoiceType)
def convert_column_to_enum(type, column): def convert_column_to_enum(type, column, registry=None):
name = '{}_{}'.format(column.table.name, column.name).upper() name = '{}_{}'.format(column.table.name, column.name).upper()
return Enum(name, type.choices, description=column.doc) return Enum(name, type.choices, description=column.doc)
@convert_sqlalchemy_type.register(postgresql.ARRAY) @convert_sqlalchemy_type.register(postgresql.ARRAY)
def convert_postgres_array_to_list(type, column): def convert_postgres_array_to_list(type, column, registry=None):
graphene_type = convert_sqlalchemy_type(column.type.item_type, column) graphene_type = convert_sqlalchemy_type(column.type.item_type, column)
return List(graphene_type, description=column.doc) return List(graphene_type, description=column.doc)
@ -85,5 +90,5 @@ def convert_postgres_array_to_list(type, column):
@convert_sqlalchemy_type.register(postgresql.HSTORE) @convert_sqlalchemy_type.register(postgresql.HSTORE)
@convert_sqlalchemy_type.register(postgresql.JSON) @convert_sqlalchemy_type.register(postgresql.JSON)
@convert_sqlalchemy_type.register(postgresql.JSONB) @convert_sqlalchemy_type.register(postgresql.JSONB)
def convert_json_to_string(type, column): def convert_json_to_string(type, column, registry=None):
return JSONString(description=column.doc) return JSONString(description=column.doc)

View File

@ -1,69 +1,35 @@
from ...core.exceptions import SkipField from sqlalchemy.orm.query import Query
from ...core.fields import Field
from ...core.types.base import FieldType
from ...core.types.definitions import List
from ...relay import ConnectionField
from ...relay.utils import is_node
from .utils import get_query, get_type_for_model, maybe_query
from graphene.relay import ConnectionField
class DefaultQuery(object): from graphql_relay.connection.arrayconnection import connection_from_list_slice
pass from .utils import get_query
class SQLAlchemyConnectionField(ConnectionField): class SQLAlchemyConnectionField(ConnectionField):
def __init__(self, *args, **kwargs):
kwargs['default'] = kwargs.pop('default', lambda: DefaultQuery)
return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs)
@property @property
def model(self): def model(self):
return self.type._meta.model return self.connection._meta.node._meta.model
def from_list(self, connection_type, resolved, args, context, info): def get_query(self, context):
if resolved is DefaultQuery: return get_query(self.model, context)
resolved = get_query(self.model, info)
query = maybe_query(resolved)
return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, context, info)
def default_resolver(self, root, args, context, info):
return getattr(root, self.source or self.attname, self.get_query(context))
class ConnectionOrListField(Field): @staticmethod
def connection_resolver(resolver, connection, root, args, context, info):
def internal_type(self, schema): iterable = resolver(root, args, context, info)
model_field = self.type if isinstance(iterable, Query):
field_object_type = model_field.get_object_type(schema) _len = iterable.count()
if not field_object_type:
raise SkipField()
if is_node(field_object_type):
field = SQLAlchemyConnectionField(field_object_type)
else: else:
field = Field(List(field_object_type)) _len = len(iterable)
field.contribute_to_class(self.object_type, self.attname) return connection_from_list_slice(
return schema.T(field) iterable,
args,
slice_start=0,
class SQLAlchemyModelField(FieldType): list_length=_len,
list_slice_length=_len,
def __init__(self, model, *args, **kwargs): connection_type=connection,
self.model = model edge_type=connection.Edge,
super(SQLAlchemyModelField, self).__init__(*args, **kwargs)
def internal_type(self, schema):
_type = self.get_object_type(schema)
if not _type and self.parent._meta.only_fields:
raise Exception(
"Table %r is not accessible by the schema. "
"You can either register the type manually "
"using @schema.register. "
"Or disable the field in %s" % (
self.model,
self.parent,
) )
)
if not _type:
raise SkipField()
return schema.T(_type)
def get_object_type(self, schema):
return get_type_for_model(schema, self.model)

View File

@ -1,24 +0,0 @@
from ...core.classtypes.objecttype import ObjectTypeOptions
from ...relay.types import Node
from ...relay.utils import is_node
VALID_ATTRS = ('model', 'only_fields', 'exclude_fields', 'identifier')
class SQLAlchemyOptions(ObjectTypeOptions):
def __init__(self, *args, **kwargs):
super(SQLAlchemyOptions, self).__init__(*args, **kwargs)
self.model = None
self.identifier = "id"
self.valid_attrs += VALID_ATTRS
self.only_fields = None
self.exclude_fields = []
self.filter_fields = None
self.filter_order_by = None
def contribute_to_class(self, cls, name):
super(SQLAlchemyOptions, self).contribute_to_class(cls, name)
if is_node(cls):
self.exclude_fields = list(self.exclude_fields) + ['id']
self.interfaces.append(Node)

View File

@ -0,0 +1,28 @@
class Registry(object):
def __init__(self):
self._registry = {}
self._registry_models = {}
def register(self, cls):
from .types import SQLAlchemyObjectType
assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__)
assert cls._meta.registry == self, 'Registry for a Model have to match.'
self._registry[cls._meta.model] = cls
def get_type_for_model(self, model):
return self._registry.get(model)
registry = None
def get_global_registry():
global registry
if not registry:
registry = Registry()
return registry
def reset_global_registry():
global registry
registry = None

View File

@ -5,11 +5,12 @@ from sqlalchemy_utils.types.choice import ChoiceType
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
import graphene import graphene
from graphene.core.types.custom_scalars import JSONString from graphene.types.json import JSONString
from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column, from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship) convert_sqlalchemy_relationship)
from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField, from ..fields import SQLAlchemyConnectionField
SQLAlchemyModelField) from ..types import SQLAlchemyObjectType, SQLAlchemyNode
from ..registry import Registry
from .models import Article, Pet, Reporter from .models import Article, Pet, Reporter
@ -100,30 +101,63 @@ def test_should_choice_convert_enum():
Table('translatedmodel', Base.metadata, column) Table('translatedmodel', Base.metadata, column)
graphene_type = convert_sqlalchemy_column(column) graphene_type = convert_sqlalchemy_column(column)
assert issubclass(graphene_type, graphene.Enum) assert issubclass(graphene_type, graphene.Enum)
assert graphene_type._meta.type_name == 'TRANSLATEDMODEL_LANGUAGE' assert graphene_type._meta.graphql_type.name == 'TRANSLATEDMODEL_LANGUAGE'
assert graphene_type._meta.description == 'Language' assert graphene_type._meta.graphql_type.description == 'Language'
assert graphene_type.__enum__.__members__['es'].value == 'Spanish' assert graphene_type._meta.enum.__members__['es'].value == 'Spanish'
assert graphene_type.__enum__.__members__['en'].value == 'English' assert graphene_type._meta.enum.__members__['en'].value == 'English'
def test_should_manytomany_convert_connectionorlist(): def test_should_manytomany_convert_connectionorlist():
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property) registry = Registry()
assert isinstance(graphene_type, ConnectionOrListField) graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, registry)
assert isinstance(graphene_type.type, SQLAlchemyModelField) assert not graphene_type
assert graphene_type.type.model == Pet
def test_should_manytomany_convert_connectionorlist_list():
class A(SQLAlchemyObjectType):
class Meta:
model = Pet
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(graphene_type, graphene.List)
assert graphene_type.of_type == A._meta.graphql_type
def test_should_manytomany_convert_connectionorlist_connection():
class A(SQLAlchemyNode, SQLAlchemyObjectType):
class Meta:
model = Pet
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(graphene_type, SQLAlchemyConnectionField)
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
field = convert_sqlalchemy_relationship(Article.reporter.property) registry = Registry()
assert isinstance(field, SQLAlchemyModelField) graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, registry)
assert field.model == Reporter assert not graphene_type
def test_should_onetomany_convert_model(): def test_should_manytoone_convert_connectionorlist_list():
graphene_type = convert_sqlalchemy_relationship(Reporter.articles.property) class A(SQLAlchemyObjectType):
assert isinstance(graphene_type, ConnectionOrListField) class Meta:
assert isinstance(graphene_type.type, SQLAlchemyModelField) model = Reporter
assert graphene_type.type.model == Article
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
assert isinstance(graphene_type, graphene.Field)
assert graphene_type.type == A._meta.graphql_type
def test_should_manytoone_convert_connectionorlist_connection():
class A(SQLAlchemyNode, SQLAlchemyObjectType):
class Meta:
model = Reporter
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
assert isinstance(graphene_type, graphene.Field)
assert graphene_type.type == A._meta.graphql_type
def test_should_postgresql_uuid_convert(): def test_should_postgresql_uuid_convert():

View File

@ -4,8 +4,8 @@ from sqlalchemy.orm import scoped_session, sessionmaker
import graphene import graphene
from graphene import relay from graphene import relay
from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField, from ..types import (SQLAlchemyNode, SQLAlchemyObjectType)
SQLAlchemyNode, SQLAlchemyObjectType) from ..fields import SQLAlchemyConnectionField
from .models import Article, Base, Editor, Reporter from .models import Article, Base, Editor, Reporter
@ -52,7 +52,7 @@ def test_should_query_well(session):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
reporters = ReporterType.List() reporters = graphene.List(ReporterType)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
return session.query(Reporter).first() return session.query(Reporter).first()
@ -93,7 +93,7 @@ def test_should_query_well(session):
def test_should_node(session): def test_should_node(session):
setup_fixtures(session) setup_fixtures(session)
class ReporterNode(SQLAlchemyNode): class ReporterNode(SQLAlchemyNode, SQLAlchemyObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
@ -105,7 +105,7 @@ def test_should_node(session):
def resolve_articles(self, *args, **kwargs): def resolve_articles(self, *args, **kwargs):
return [Article(headline='Hi!')] return [Article(headline='Hi!')]
class ArticleNode(SQLAlchemyNode): class ArticleNode(SQLAlchemyNode, SQLAlchemyObjectType):
class Meta: class Meta:
model = Article model = Article
@ -115,7 +115,7 @@ def test_should_node(session):
# return Article(id=1, headline='Article node') # return Article(id=1, headline='Article node')
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = relay.NodeField() node = SQLAlchemyNode.Field()
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleNode)
@ -185,8 +185,8 @@ def test_should_node(session):
'headline': 'Hi!' 'headline': 'Hi!'
} }
} }
schema = graphene.Schema(query=Query, session=session) schema = graphene.Schema(query=Query)
result = schema.execute(query) result = schema.execute(query, context_value={'session': session})
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected
@ -194,14 +194,13 @@ def test_should_node(session):
def test_should_custom_identifier(session): def test_should_custom_identifier(session):
setup_fixtures(session) setup_fixtures(session)
class EditorNode(SQLAlchemyNode): class EditorNode(SQLAlchemyNode, SQLAlchemyObjectType):
class Meta: class Meta:
model = Editor model = Editor
identifier = "editor_id"
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = relay.NodeField(EditorNode) node = SQLAlchemyNode.Field()
all_editors = SQLAlchemyConnectionField(EditorNode) all_editors = SQLAlchemyConnectionField(EditorNode)
query = ''' query = '''
@ -215,9 +214,11 @@ def test_should_custom_identifier(session):
} }
}, },
node(id: "RWRpdG9yTm9kZTox") { node(id: "RWRpdG9yTm9kZTox") {
...on EditorNode {
name name
} }
} }
}
''' '''
expected = { expected = {
'allEditors': { 'allEditors': {
@ -233,7 +234,7 @@ def test_should_custom_identifier(session):
} }
} }
schema = graphene.Schema(query=Query, session=session) schema = graphene.Schema(query=Query)
result = schema.execute(query) result = schema.execute(query, context_value={'session': session})
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected

View File

@ -1,25 +1,24 @@
from py.test import raises from py.test import raises
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType from ..types import SQLAlchemyObjectType
from tests.utils import assert_equal_lists
from .models import Reporter from .models import Reporter
from ..registry import Registry
def test_should_raise_if_no_model(): def test_should_raise_if_no_model():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character1(SQLAlchemyObjectType): class Character1(SQLAlchemyObjectType):
pass pass
assert 'model in the Meta' in str(excinfo.value) assert 'valid SQLAlchemy Model' in str(excinfo.value)
def test_should_raise_if_model_is_invalid(): def test_should_raise_if_model_is_invalid():
with raises(Exception) as excinfo: with raises(Exception) as excinfo:
class Character2(SQLAlchemyObjectType): class Character2(SQLAlchemyObjectType):
class Meta: class Meta:
model = 1 model = 1
assert 'not a SQLAlchemy model' in str(excinfo.value) assert 'valid SQLAlchemy Model' in str(excinfo.value)
def test_should_map_fields_correctly(): def test_should_map_fields_correctly():
@ -27,10 +26,9 @@ def test_should_map_fields_correctly():
class Meta: class Meta:
model = Reporter model = Reporter
assert_equal_lists( registry = Registry()
ReporterType2._meta.fields_map.keys(),
['articles', 'first_name', 'last_name', 'email', 'pets', 'id'] assert ReporterType2._meta.get_fields().keys() == ['id', 'firstName', 'lastName', 'email']
)
def test_should_map_only_few_fields(): def test_should_map_only_few_fields():
@ -38,8 +36,5 @@ def test_should_map_only_few_fields():
class Meta: class Meta:
model = Reporter model = Reporter
only_fields = ('id', 'email') only = ('id', 'email')
assert_equal_lists( assert Reporter2._meta.get_fields().keys() == ['id', 'email']
Reporter2._meta.fields_map.keys(),
['id', 'email']
)

View File

@ -1,38 +1,42 @@
from graphql.type import GraphQLObjectType from graphql.type import GraphQLObjectType, GraphQLInterfaceType
from graphql.type.definition import GraphQLFieldDefinition
from graphql import GraphQLInt
from pytest import raises from pytest import raises
from graphene import Schema from graphene import Schema
from graphene.contrib.sqlalchemy.types import (SQLAlchemyNode, from ..types import (SQLAlchemyNode, SQLAlchemyObjectType)
SQLAlchemyObjectType) from ..registry import Registry
from graphene.core.fields import Field
from graphene.core.types.scalars import Int from graphene import Field, Int
from graphene.relay.fields import GlobalIDField # from tests.utils import assert_equal_lists
from tests.utils import assert_equal_lists
from .models import Article, Reporter from .models import Article, Reporter
schema = Schema() registry = Registry()
class Character(SQLAlchemyObjectType): class Character(SQLAlchemyObjectType):
'''Character description''' '''Character description'''
class Meta: class Meta:
model = Reporter model = Reporter
registry = registry
@schema.register class Human(SQLAlchemyNode, SQLAlchemyObjectType):
class Human(SQLAlchemyNode):
'''Human description''' '''Human description'''
pub_date = Int() pub_date = Int()
class Meta: class Meta:
model = Article model = Article
exclude_fields = ('id', ) exclude = ('id', )
registry = registry
def test_sqlalchemy_interface(): def test_sqlalchemy_interface():
assert SQLAlchemyNode._meta.interface is True assert isinstance(SQLAlchemyNode._meta.graphql_type, GraphQLInterfaceType)
# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1))
@ -43,60 +47,31 @@ def test_sqlalchemy_interface():
def test_objecttype_registered(): def test_objecttype_registered():
object_type = schema.T(Character) object_type = Character._meta.graphql_type
assert isinstance(object_type, GraphQLObjectType) assert isinstance(object_type, GraphQLObjectType)
assert Character._meta.model == Reporter assert Character._meta.model == Reporter
assert_equal_lists( assert object_type.get_fields().keys() == ['articles', 'id', 'firstName', 'lastName', 'email']
object_type.get_fields().keys(),
['articles', 'firstName', 'lastName', 'email', 'id']
)
def test_sqlalchemynode_idfield(): # def test_sqlalchemynode_idfield():
idfield = SQLAlchemyNode._meta.fields_map['id'] # idfield = SQLAlchemyNode._meta.fields_map['id']
assert isinstance(idfield, GlobalIDField) # assert isinstance(idfield, GlobalIDField)
def test_node_idfield(): # def test_node_idfield():
idfield = Human._meta.fields_map['id'] # idfield = Human._meta.fields_map['id']
assert isinstance(idfield, GlobalIDField) # assert isinstance(idfield, GlobalIDField)
def test_node_replacedfield(): def test_node_replacedfield():
idfield = Human._meta.fields_map['pub_date'] idfield = Human._meta.graphql_type.get_fields()['pubDate']
assert isinstance(idfield, Field) assert isinstance(idfield, GraphQLFieldDefinition)
assert schema.T(idfield).type == schema.T(Int()) assert idfield.type == GraphQLInt
def test_interface_objecttype_init_none():
h = Human()
assert h._root is None
def test_interface_objecttype_init_good():
instance = Article()
h = Human(instance)
assert h._root == instance
def test_interface_objecttype_init_unexpected():
with raises(AssertionError) as excinfo:
Human(object())
assert str(excinfo.value) == "Human received a non-compatible instance (object) when expecting Article"
def test_object_type(): def test_object_type():
object_type = schema.T(Human) object_type = Human._meta.graphql_type
Human._meta.fields_map object_type.get_fields()
assert Human._meta.interface is False
assert isinstance(object_type, GraphQLObjectType) assert isinstance(object_type, GraphQLObjectType)
assert_equal_lists( assert object_type.get_fields().keys() == ['id', 'pubDate', 'reporter', 'headline', 'reporterId']
object_type.get_fields().keys(), assert SQLAlchemyNode._meta.graphql_type in object_type.get_interfaces()
['headline', 'id', 'reporter', 'reporterId', 'pubDate']
)
assert schema.T(SQLAlchemyNode) in object_type.get_interfaces()
def test_node_notinterface():
assert Human._meta.interface is False
assert SQLAlchemyNode in Human._meta.interfaces

View File

@ -5,13 +5,12 @@ from ..utils import get_session
def test_get_session(): def test_get_session():
session = 'My SQLAlchemy session' session = 'My SQLAlchemy session'
schema = Schema(session=session)
class Query(ObjectType): class Query(ObjectType):
x = String() x = String()
def resolve_x(self, args, info): def resolve_x(self, args, context, info):
return get_session(info) return get_session(context)
query = ''' query = '''
query ReporterQuery { query ReporterQuery {
@ -19,7 +18,7 @@ def test_get_session():
} }
''' '''
schema = Schema(query=Query, session=session) schema = Schema(query=Query)
result = schema.execute(query) result = schema.execute(query, context_value={'session': session})
assert not result.errors assert not result.errors
assert result.data['x'] == session assert result.data['x'] == session

View File

@ -1,125 +1,158 @@
import inspect
import six import six
from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.inspection import inspect as sqlalchemyinspect
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta from graphene import ObjectType
from ...relay.types import Connection, Node, NodeMeta from graphene.relay import Node
from .converter import (convert_sqlalchemy_column, from .converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship) convert_sqlalchemy_relationship)
from .options import SQLAlchemyOptions from .utils import is_mapped
from .utils import get_query, is_mapped
from functools import partial
class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): from graphene import Field, Interface
options_class = SQLAlchemyOptions from graphene.types.options import Options
from graphene.types.objecttype import attrs_without_fields, get_interfaces
def construct_fields(cls): from .registry import Registry, get_global_registry
only_fields = cls._meta.only_fields from .utils import get_query
exclude_fields = cls._meta.exclude_fields from graphene.utils.is_base_type import is_base_type
already_created_fields = {f.attname for f in cls._meta.local_fields} from graphene.utils.copy_fields import copy_fields
from graphene.utils.get_graphql_type import get_graphql_type
from graphene.utils.get_fields import get_fields
from graphene.utils.as_field import as_field
from graphene.generators import generate_objecttype
class SQLAlchemyObjectTypeMeta(type(ObjectType)):
def _construct_fields(cls, fields, options):
only_fields = cls._meta.only
exclude_fields = cls._meta.exclude
inspected_model = sqlalchemyinspect(cls._meta.model) inspected_model = sqlalchemyinspect(cls._meta.model)
# Get all the columns for the relationships on the model # Get all the columns for the relationships on the model
for relationship in inspected_model.relationships: for relationship in inspected_model.relationships:
is_not_in_only = only_fields and relationship.key not in only_fields is_not_in_only = only_fields and relationship.key not in only_fields
is_already_created = relationship.key in already_created_fields is_already_created = relationship.key in fields
is_excluded = relationship.key in exclude_fields or is_already_created is_excluded = relationship.key in exclude_fields or is_already_created
if is_not_in_only or is_excluded: if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not # We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields # in there. Or when we excldue this field in exclude_fields
continue continue
converted_relationship = convert_sqlalchemy_relationship(relationship) converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry)
cls.add_to_class(relationship.key, converted_relationship) if not converted_relationship:
continue
name = relationship.key
fields[name] = as_field(converted_relationship)
for name, column in inspected_model.columns.items(): for name, column in inspected_model.columns.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
is_already_created = name in already_created_fields is_already_created = name in fields
is_excluded = name in exclude_fields or is_already_created is_excluded = name in exclude_fields or is_already_created
if is_not_in_only or is_excluded: if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not # We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields # in there. Or when we excldue this field in exclude_fields
continue continue
converted_column = convert_sqlalchemy_column(column) converted_column = convert_sqlalchemy_column(column, options.registry)
cls.add_to_class(name, converted_column) if not converted_column:
continue
fields[name] = as_field(converted_column)
def construct(cls, *args, **kwargs): fields = copy_fields(Field, fields, parent=cls)
cls = super(SQLAlchemyObjectTypeMeta, cls).construct(*args, **kwargs)
if not cls._meta.abstract: return fields
if not cls._meta.model:
raise Exception( @staticmethod
'SQLAlchemy ObjectType %s must have a model in the Meta class attr' % def _create_objecttype(cls, name, bases, attrs):
cls) # super_new = super(SQLAlchemyObjectTypeMeta, cls).__new__
elif not inspect.isclass(cls._meta.model) or not is_mapped(cls._meta.model): super_new = type.__new__
raise Exception('Provided model in %s is not a SQLAlchemy model' % cls)
# Also ensure initialization is only performed for subclasses of Model
# (excluding Model class itself).
if not is_base_type(bases, SQLAlchemyObjectTypeMeta):
return super_new(cls, name, bases, attrs)
options = Options(
attrs.pop('Meta', None),
name=None,
description=None,
model=None,
fields=(),
exclude=(),
only=(),
interfaces=(),
registry=None
)
if not options.registry:
options.registry = get_global_registry()
assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry)
assert is_mapped(options.model), 'You need to pass a valid SQLAlchemy Model in {}.Meta, received "{}".'.format(name, options.model)
interfaces = tuple(options.interfaces)
fields = get_fields(ObjectType, attrs, bases, interfaces)
attrs = attrs_without_fields(attrs, fields)
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
base_interfaces = tuple(b for b in bases if issubclass(b, Interface))
options.get_fields = partial(cls._construct_fields, fields, options)
options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces))
options.graphql_type = generate_objecttype(cls)
if issubclass(cls, SQLAlchemyObjectType):
options.registry.register(cls)
cls.construct_fields()
return cls return cls
class InstanceObjectType(ObjectType): class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)):
is_type_of = None
class Meta:
abstract = True
def __init__(self, _root=None):
super(InstanceObjectType, self).__init__(_root=_root)
assert not self._root or isinstance(self._root, self._meta.model), (
'{} received a non-compatible instance ({}) '
'when expecting {}'.format(
self.__class__.__name__,
self._root.__class__.__name__,
self._meta.model.__name__
))
@property
def instance(self):
return self._root
@instance.setter
def instance(self, value):
self._root = value
class SQLAlchemyObjectType(six.with_metaclass( class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, type(Node)):
SQLAlchemyObjectTypeMeta, InstanceObjectType)):
class Meta: @staticmethod
abstract = True def _get_interface_options(meta):
return Options(
meta,
name=None,
description=None,
model=None,
graphql_type=None,
registry=False
)
@staticmethod
def _create_interface(cls, name, bases, attrs):
cls = super(SQLAlchemyNodeMeta, cls)._create_interface(cls, name, bases, attrs)
if not cls._meta.registry:
cls._meta.registry = get_global_registry()
assert isinstance(cls._meta.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry.'.format(name)
return cls
class SQLAlchemyConnection(Connection): class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)):
pass
class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, NodeMeta):
pass
class NodeInstance(Node, InstanceObjectType):
class Meta:
abstract = True
class SQLAlchemyNode(six.with_metaclass(
SQLAlchemyNodeMeta, NodeInstance)):
class Meta:
abstract = True
def to_global_id(self):
id_ = getattr(self.instance, self._meta.identifier)
return self.global_id(id_)
@classmethod @classmethod
def get_node(cls, id, info=None): def get_node(cls, id, context, info):
try: try:
model = cls._meta.model model = cls._meta.model
identifier = cls._meta.identifier query = get_query(model, context)
query = get_query(model, info) return query.get(id)
instance = query.filter(getattr(model, identifier) == id).one()
return cls(instance)
except NoResultFound: except NoResultFound:
return None return None
@classmethod
def resolve_id(cls, root, args, context, info):
return root.__mapper__.primary_key_from_instance(root)[0]
@classmethod
def resolve_type(cls, type_instance, context, info):
# We get the model from the _meta in the SQLAlchemy class/instance
model = type(type_instance)
graphene_type = cls._meta.registry.get_type_for_model(model)
if graphene_type:
return get_graphql_type(graphene_type)
raise Exception("Type not found for model \"{}\"".format(model))

View File

@ -1,28 +1,14 @@
from sqlalchemy.ext.declarative.api import DeclarativeMeta from sqlalchemy.ext.declarative.api import DeclarativeMeta
from sqlalchemy.orm.query import Query
from graphene.utils import LazyList
def get_type_for_model(schema, model): def get_session(context):
schema = schema return context.get('session')
types = schema.types.values()
for _type in types:
type_model = hasattr(_type, '_meta') and getattr(
_type._meta, 'model', None)
if model == type_model:
return _type
def get_session(info): def get_query(model, context):
schema = info.schema.graphene_schema
return schema.options.get('session')
def get_query(model, info):
query = getattr(model, 'query', None) query = getattr(model, 'query', None)
if not query: if not query:
session = get_session(info) session = get_session(context)
if not session: if not session:
raise Exception('A query in the model Base or a session in the schema is required for querying.\n' raise Exception('A query in the model Base or a session in the schema is required for querying.\n'
'Read more http://graphene-python.org/docs/sqlalchemy/tips/#querying') 'Read more http://graphene-python.org/docs/sqlalchemy/tips/#querying')
@ -30,20 +16,5 @@ def get_query(model, info):
return query return query
class WrappedQuery(LazyList):
def __len__(self):
# Dont calculate the length using len(query), as this will
# evaluate the whole queryset and return it's length.
# Use .count() instead
return self._origin.count()
def maybe_query(value):
if isinstance(value, Query):
return WrappedQuery(value)
return value
def is_mapped(obj): def is_mapped(obj):
return isinstance(obj, DeclarativeMeta) return isinstance(obj, DeclarativeMeta)

View File

@ -96,11 +96,10 @@ class IterableConnectionField(Field):
@property @property
def connection(self): def connection(self):
from .node import Node from .node import Node
graphql_type = super(IterableConnectionField, self).type if issubclass(self._type, Node):
if issubclass(graphql_type.graphene_type, Node): connection_type = self._type.get_default_connection()
connection_type = graphql_type.graphene_type.get_default_connection()
else: else:
connection_type = graphql_type.graphene_type connection_type = self._type
assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self)) assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self))
return connection_type return connection_type

View File

@ -84,9 +84,13 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
# return to_global_id(type, id) # return to_global_id(type, id)
# raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__)) # raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__))
@classmethod
def resolve_id(cls, root, args, context, info):
return getattr(root, 'id', None)
@classmethod @classmethod
def id_resolver(cls, root, args, context, info): def id_resolver(cls, root, args, context, info):
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None)) return cls.to_global_id(info.parent_type.name, cls.resolve_id(root, args, context, info))
@classmethod @classmethod
def get_node_from_global_id(cls, global_id, context, info): def get_node_from_global_id(cls, global_id, context, info):