Updated SQLAlchemy code to work with latest version of graphene.

This commit is contained in:
Syrus Akbary 2016-08-14 15:42:27 -07:00
parent c414b3f688
commit f296a2a73f
10 changed files with 161 additions and 179 deletions

View File

@ -1,33 +1,35 @@
import graphene import graphene
from graphene import relay from graphene import relay
from graphene_sqlalchemy import (SQLAlchemyConnectionField, from graphene_sqlalchemy import (SQLAlchemyConnectionField,
SQLAlchemyObjectType, SQLAlchemyObjectType)
SQLAlchemyNode)
from models import Department as DepartmentModel from models import Department as DepartmentModel
from models import Employee as EmployeeModel from models import Employee as EmployeeModel
from models import Role as RoleModel from models import Role as RoleModel
class Department(SQLAlchemyNode, SQLAlchemyObjectType): class Department(SQLAlchemyObjectType):
class Meta: class Meta:
model = DepartmentModel model = DepartmentModel
interfaces = (relay.Node, )
class Employee(SQLAlchemyNode, SQLAlchemyObjectType): class Employee(SQLAlchemyObjectType):
class Meta: class Meta:
model = EmployeeModel model = EmployeeModel
interfaces = (relay.Node, )
class Role(SQLAlchemyNode, SQLAlchemyObjectType): class Role(SQLAlchemyObjectType):
class Meta: class Meta:
model = RoleModel model = RoleModel
interfaces = (relay.Node, )
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = SQLAlchemyNode.Field() node = relay.Node.Field()
all_employees = SQLAlchemyConnectionField(Employee) all_employees = SQLAlchemyConnectionField(Employee)
all_roles = SQLAlchemyConnectionField(Role) all_roles = SQLAlchemyConnectionField(Role)
role = graphene.Field(Role) role = graphene.Field(Role)

View File

@ -1,10 +1,9 @@
from .types import ( from .types import (
SQLAlchemyObjectType, SQLAlchemyObjectType,
SQLAlchemyNode
) )
from .fields import ( from .fields import (
SQLAlchemyConnectionField SQLAlchemyConnectionField
) )
__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyNode', __all__ = ['SQLAlchemyObjectType',
'SQLAlchemyConnectionField'] 'SQLAlchemyConnectionField']

View File

@ -3,8 +3,8 @@ from sqlalchemy import types
from sqlalchemy.orm import interfaces from sqlalchemy.orm import interfaces
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from graphene import Enum, ID, Boolean, Float, Int, String, List, Field from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic
from graphene.relay import Node from graphene.relay import is_node
from graphene.types.json import JSONString from graphene.types.json import JSONString
from .fields import SQLAlchemyConnectionField from .fields import SQLAlchemyConnectionField
@ -18,16 +18,20 @@ except ImportError:
def convert_sqlalchemy_relationship(relationship, registry): def convert_sqlalchemy_relationship(relationship, registry):
direction = relationship.direction direction = relationship.direction
model = relationship.mapper.entity model = relationship.mapper.entity
_type = registry.get_type_for_model(model)
if not _type: def dynamic_type():
return None _type = registry.get_type_for_model(model)
if direction == interfaces.MANYTOONE: if not _type:
return Field(_type) return None
elif (direction == interfaces.ONETOMANY or if direction == interfaces.MANYTOONE:
direction == interfaces.MANYTOMANY): return Field(_type)
if issubclass(_type, Node): elif (direction == interfaces.ONETOMANY or
return SQLAlchemyConnectionField(_type) direction == interfaces.MANYTOMANY):
return List(_type) if is_node(_type):
return SQLAlchemyConnectionField(_type)
return Field(List(_type))
return Dynamic(dynamic_type)
def convert_sqlalchemy_column(column, registry=None): def convert_sqlalchemy_column(column, registry=None):

View File

@ -1,3 +1,4 @@
from functools import partial
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from graphene.relay import ConnectionField from graphene.relay import ConnectionField
@ -9,17 +10,13 @@ class SQLAlchemyConnectionField(ConnectionField):
@property @property
def model(self): def model(self):
return self.connection._meta.node._meta.model return self.type._meta.node._meta.model
def get_query(self, context):
return get_query(self.model, context)
def default_resolver(self, root, args, context, info):
return getattr(root, self.source or self.attname, self.get_query(context))
@staticmethod @staticmethod
def connection_resolver(resolver, connection, root, args, context, info): def connection_resolver(resolver, connection, model, root, args, context, info):
iterable = resolver(root, args, context, info) iterable = resolver(root, args, context, info)
if iterable is None:
iterable = get_query(model, context)
if isinstance(iterable, Query): if isinstance(iterable, Query):
_len = iterable.count() _len = iterable.count()
else: else:
@ -33,3 +30,6 @@ class SQLAlchemyConnectionField(ConnectionField):
connection_type=connection, connection_type=connection,
edge_type=connection.Edge, edge_type=connection.Edge,
) )
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)

View File

@ -7,10 +7,10 @@ class Registry(object):
from .types import SQLAlchemyObjectType from .types import SQLAlchemyObjectType
assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__) assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__)
assert cls._meta.registry == self, 'Registry for a Model have to match.' assert cls._meta.registry == self, 'Registry for a Model have to match.'
assert self.get_type_for_model(cls._meta.model) in [None, cls], ( # assert self.get_type_for_model(cls._meta.model) in [None, cls], (
'SQLAlchemy model "{}" already associated with ' # 'SQLAlchemy model "{}" already associated with '
'another type "{}".' # 'another type "{}".'
).format(cls._meta.model, self._registry[cls._meta.model]) # ).format(cls._meta.model, self._registry[cls._meta.model])
self._registry[cls._meta.model] = cls self._registry[cls._meta.model] = cls
def get_type_for_model(self, model): def get_type_for_model(self, model):

View File

@ -5,11 +5,12 @@ from sqlalchemy_utils.types.choice import ChoiceType
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
import graphene import graphene
from graphene.relay import Node
from graphene.types.json import JSONString from graphene.types.json import JSONString
from ..converter import (convert_sqlalchemy_column, from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship) convert_sqlalchemy_relationship)
from ..fields import SQLAlchemyConnectionField from ..fields import SQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType, SQLAlchemyNode from ..types import SQLAlchemyObjectType
from ..registry import Registry from ..registry import Registry
from .models import Article, Pet, Reporter from .models import Article, Pet, Reporter
@ -19,7 +20,7 @@ def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs):
column = Column(sqlalchemy_type, doc='Custom Help Text', **kwargs) column = Column(sqlalchemy_type, doc='Custom Help Text', **kwargs)
graphene_type = convert_sqlalchemy_column(column) graphene_type = convert_sqlalchemy_column(column)
assert isinstance(graphene_type, graphene_field) assert isinstance(graphene_type, graphene_field)
field = graphene_type.as_field() field = graphene_type.Field()
assert field.description == 'Custom Help Text' assert field.description == 'Custom Help Text'
return field return field
@ -101,16 +102,17 @@ def test_should_choice_convert_enum():
Table('translatedmodel', Base.metadata, column) Table('translatedmodel', Base.metadata, column)
graphene_type = convert_sqlalchemy_column(column) graphene_type = convert_sqlalchemy_column(column)
assert issubclass(graphene_type, graphene.Enum) assert issubclass(graphene_type, graphene.Enum)
assert graphene_type._meta.graphql_type.name == 'TRANSLATEDMODEL_LANGUAGE' assert graphene_type._meta.name == 'TRANSLATEDMODEL_LANGUAGE'
assert graphene_type._meta.graphql_type.description == 'Language' assert graphene_type._meta.description == 'Language'
assert graphene_type._meta.enum.__members__['es'].value == 'Spanish' assert graphene_type._meta.enum.__members__['es'].value == 'Spanish'
assert graphene_type._meta.enum.__members__['en'].value == 'English' assert graphene_type._meta.enum.__members__['en'].value == 'English'
def test_should_manytomany_convert_connectionorlist(): def test_should_manytomany_convert_connectionorlist():
registry = Registry() registry = Registry()
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry)
assert not graphene_type assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
def test_should_manytomany_convert_connectionorlist_list(): def test_should_manytomany_convert_connectionorlist_list():
@ -118,26 +120,30 @@ def test_should_manytomany_convert_connectionorlist_list():
class Meta: class Meta:
model = Pet model = Pet
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(graphene_type, graphene.List) assert isinstance(dynamic_field, graphene.Dynamic)
assert graphene_type.of_type == A._meta.graphql_type graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field)
assert isinstance(graphene_type.type, graphene.List)
assert graphene_type.type.of_type == A
def test_should_manytomany_convert_connectionorlist_connection(): def test_should_manytomany_convert_connectionorlist_connection():
class A(SQLAlchemyNode, SQLAlchemyObjectType): class A(SQLAlchemyObjectType):
class Meta: class Meta:
model = Pet model = Pet
interfaces = (Node, )
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(graphene_type, SQLAlchemyConnectionField) assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField)
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
registry = Registry() registry = Registry()
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry)
assert not graphene_type assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
def test_should_manytoone_convert_connectionorlist_list(): def test_should_manytoone_convert_connectionorlist_list():
@ -145,19 +151,24 @@ def test_should_manytoone_convert_connectionorlist_list():
class Meta: class Meta:
model = Reporter model = Reporter
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
assert graphene_type.type == A._meta.graphql_type assert graphene_type.type == A
def test_should_manytoone_convert_connectionorlist_connection(): def test_should_manytoone_convert_connectionorlist_connection():
class A(SQLAlchemyNode, SQLAlchemyObjectType): class A(SQLAlchemyObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, )
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
assert graphene_type.type == A._meta.graphql_type assert graphene_type.type == A
def test_should_postgresql_uuid_convert(): def test_should_postgresql_uuid_convert():

View File

@ -3,8 +3,8 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
import graphene import graphene
from graphene import relay from graphene.relay import Node
from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) from ..types import SQLAlchemyObjectType
from ..fields import SQLAlchemyConnectionField from ..fields import SQLAlchemyConnectionField
from .models import Article, Base, Editor, Reporter from .models import Article, Base, Editor, Reporter
@ -93,10 +93,11 @@ def test_should_query_well(session):
def test_should_node(session): def test_should_node(session):
setup_fixtures(session) setup_fixtures(session)
class ReporterNode(SQLAlchemyNode, SQLAlchemyObjectType): class ReporterNode(SQLAlchemyObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
interfaces = (Node, )
@classmethod @classmethod
def get_node(cls, id, info): def get_node(cls, id, info):
@ -105,17 +106,18 @@ def test_should_node(session):
def resolve_articles(self, *args, **kwargs): def resolve_articles(self, *args, **kwargs):
return [Article(headline='Hi!')] return [Article(headline='Hi!')]
class ArticleNode(SQLAlchemyNode, SQLAlchemyObjectType): class ArticleNode(SQLAlchemyObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, )
# @classmethod # @classmethod
# def get_node(cls, id, info): # def get_node(cls, id, info):
# return Article(id=1, headline='Article node') # return Article(id=1, headline='Article node')
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = SQLAlchemyNode.Field() node = Node.Field()
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleNode)
@ -194,13 +196,14 @@ def test_should_node(session):
def test_should_custom_identifier(session): def test_should_custom_identifier(session):
setup_fixtures(session) setup_fixtures(session)
class EditorNode(SQLAlchemyNode, SQLAlchemyObjectType): class EditorNode(SQLAlchemyObjectType):
class Meta: class Meta:
model = Editor model = Editor
interfaces = (Node, )
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = SQLAlchemyNode.Field() node = Node.Field()
all_editors = SQLAlchemyConnectionField(EditorNode) all_editors = SQLAlchemyConnectionField(EditorNode)
query = ''' query = '''

View File

@ -28,7 +28,7 @@ def test_should_map_fields_correctly():
model = Reporter model = Reporter
registry = Registry() registry = Registry()
assert ReporterType2._meta.get_fields().keys() == ['id', 'firstName', 'lastName', 'email'] assert ReporterType2._meta.fields.keys() == ['id', 'first_name', 'last_name', 'email', 'pets', 'articles']
def test_should_map_only_few_fields(): def test_should_map_only_few_fields():
@ -36,5 +36,5 @@ def test_should_map_only_few_fields():
class Meta: class Meta:
model = Reporter model = Reporter
only = ('id', 'email') only_fields = ('id', 'email')
assert Reporter2._meta.get_fields().keys() == ['id', 'email'] assert Reporter2._meta.fields.keys() == ['id', 'email']

View File

@ -1,10 +1,10 @@
from graphql.type import GraphQLObjectType, GraphQLInterfaceType from graphql.type import GraphQLObjectType, GraphQLInterfaceType
from graphql.type.definition import GraphQLFieldDefinition
from graphql import GraphQLInt from graphql import GraphQLInt
from pytest import raises from pytest import raises
from graphene import Schema from graphene import Schema, Interface, ObjectType
from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) from graphene.relay import Node, is_node
from ..types import SQLAlchemyObjectType
from ..registry import Registry from ..registry import Registry
from graphene import Field, Int from graphene import Field, Int
@ -14,6 +14,7 @@ from .models import Article, Reporter
registry = Registry() registry = Registry()
class Character(SQLAlchemyObjectType): class Character(SQLAlchemyObjectType):
'''Character description''' '''Character description'''
class Meta: class Meta:
@ -21,22 +22,21 @@ class Character(SQLAlchemyObjectType):
registry = registry registry = registry
class Human(SQLAlchemyNode, SQLAlchemyObjectType): class Human(SQLAlchemyObjectType):
'''Human description''' '''Human description'''
pub_date = Int() pub_date = Int()
class Meta: class Meta:
model = Article model = Article
exclude = ('id', ) exclude_fields = ('id', )
registry = registry registry = registry
interfaces = (Node, )
def test_sqlalchemy_interface(): def test_sqlalchemy_interface():
assert isinstance(SQLAlchemyNode._meta.graphql_type, GraphQLInterfaceType) assert issubclass(Node, Interface)
assert issubclass(Node, Node)
# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1))
@ -47,14 +47,13 @@ def test_sqlalchemy_interface():
def test_objecttype_registered(): def test_objecttype_registered():
object_type = Character._meta.graphql_type assert issubclass(Character, ObjectType)
assert isinstance(object_type, GraphQLObjectType)
assert Character._meta.model == Reporter assert Character._meta.model == Reporter
assert object_type.get_fields().keys() == ['articles', 'id', 'firstName', 'lastName', 'email'] assert Character._meta.fields.keys() == ['id', 'first_name', 'last_name', 'email', 'pets', 'articles']
# def test_sqlalchemynode_idfield(): # def test_sqlalchemynode_idfield():
# idfield = SQLAlchemyNode._meta.fields_map['id'] # idfield = Node._meta.fields_map['id']
# assert isinstance(idfield, GlobalIDField) # assert isinstance(idfield, GlobalIDField)
@ -64,14 +63,12 @@ def test_objecttype_registered():
def test_node_replacedfield(): def test_node_replacedfield():
idfield = Human._meta.graphql_type.get_fields()['pubDate'] idfield = Human._meta.fields['pub_date']
assert isinstance(idfield, GraphQLFieldDefinition) assert isinstance(idfield, Field)
assert idfield.type == GraphQLInt assert idfield.type == Int
def test_object_type(): def test_object_type():
object_type = Human._meta.graphql_type assert issubclass(Human, ObjectType)
object_type.get_fields() assert Human._meta.fields.keys() == ['id', 'pub_date', 'headline', 'reporter_id', 'reporter']
assert isinstance(object_type, GraphQLObjectType) assert is_node(Human)
assert object_type.get_fields().keys() == ['id', 'pubDate', 'reporter', 'headline', 'reporterId']
assert SQLAlchemyNode._meta.graphql_type in object_type.get_interfaces()

View File

@ -1,138 +1,112 @@
from collections import OrderedDict
import six import six
from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.inspection import inspect as sqlalchemyinspect
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from graphene import ObjectType from graphene import ObjectType
from graphene.relay import Node from graphene.relay import is_node
from .converter import (convert_sqlalchemy_column, from .converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship) convert_sqlalchemy_relationship)
from .utils import is_mapped from .utils import is_mapped
from functools import partial from graphene.types.objecttype import ObjectTypeMeta
from graphene import Field, Interface
from graphene.types.options import Options from graphene.types.options import Options
from graphene.types.objecttype import attrs_without_fields, get_interfaces
from .registry import Registry, get_global_registry from .registry import Registry, get_global_registry
from .utils import get_query
from graphene.utils.is_base_type import is_base_type from graphene.utils.is_base_type import is_base_type
from graphene.utils.copy_fields import copy_fields from graphene.types.utils import get_fields_in_type
from graphene.utils.get_graphql_type import get_graphql_type from .utils import get_query
from graphene.utils.get_fields import get_fields
from graphene.utils.as_field import as_field
from graphene.generators import generate_objecttype
class SQLAlchemyObjectTypeMeta(type(ObjectType)): class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
def _construct_fields(cls, fields, options): def _construct_fields(cls, all_fields, options):
only_fields = cls._meta.only only_fields = cls._meta.only_fields
exclude_fields = cls._meta.exclude exclude_fields = cls._meta.exclude_fields
inspected_model = sqlalchemyinspect(cls._meta.model) inspected_model = sqlalchemyinspect(cls._meta.model)
# Get all the columns for the relationships on the model fields = OrderedDict()
for relationship in inspected_model.relationships:
is_not_in_only = only_fields and relationship.key not in only_fields
is_already_created = relationship.key in fields
is_excluded = relationship.key in exclude_fields or is_already_created
if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields
continue
converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry)
if not converted_relationship:
continue
name = relationship.key
fields[name] = as_field(converted_relationship)
for name, column in inspected_model.columns.items(): for name, column in inspected_model.columns.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
is_already_created = name in fields is_already_created = name in all_fields
is_excluded = name in exclude_fields or is_already_created is_excluded = name in exclude_fields or is_already_created
if is_not_in_only or is_excluded: if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not # We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields # in there. Or when we excldue this field in exclude_fields
continue continue
converted_column = convert_sqlalchemy_column(column, options.registry) converted_column = convert_sqlalchemy_column(column, options.registry)
if not converted_column: fields[name] = converted_column
continue
fields[name] = as_field(converted_column)
fields = copy_fields(Field, fields, parent=cls) # Get all the columns for the relationships on the model
for relationship in inspected_model.relationships:
is_not_in_only = only_fields and relationship.key not in only_fields
is_already_created = relationship.key in all_fields
is_excluded = relationship.key in exclude_fields or is_already_created
if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields
continue
converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry)
name = relationship.key
fields[name] = converted_relationship
return fields return fields
@staticmethod @staticmethod
def _create_objecttype(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
# super_new = super(SQLAlchemyObjectTypeMeta, cls).__new__
super_new = type.__new__
# Also ensure initialization is only performed for subclasses of Model # Also ensure initialization is only performed for subclasses of Model
# (excluding Model class itself). # (excluding Model class itself).
if not is_base_type(bases, SQLAlchemyObjectTypeMeta): if not is_base_type(bases, SQLAlchemyObjectTypeMeta):
return super_new(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
options = Options( options = Options(
attrs.pop('Meta', None), attrs.pop('Meta', None),
name=None, name=name,
description=None, description=attrs.pop('__doc__', None),
model=None, model=None,
fields=(), fields=None,
exclude=(), only_fields=(),
only=(), exclude_fields=(),
id='id',
interfaces=(), interfaces=(),
registry=None registry=None
) )
if not options.registry: if not options.registry:
options.registry = get_global_registry() options.registry = get_global_registry()
assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry) assert isinstance(options.registry, Registry), (
assert is_mapped(options.model), 'You need to pass a valid SQLAlchemy Model in {}.Meta, received "{}".'.format(name, options.model) 'The attribute registry in {}.Meta needs to be an'
' instance of Registry, received "{}".'
).format(name, options.registry)
assert is_mapped(options.model), (
'You need to pass a valid SQLAlchemy Model in '
'{}.Meta, received "{}".'
).format(name, options.model)
interfaces = tuple(options.interfaces)
fields = get_fields(ObjectType, attrs, bases, interfaces)
attrs = attrs_without_fields(attrs, fields)
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
base_interfaces = tuple(b for b in bases if issubclass(b, Interface)) cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options))
options.get_fields = partial(cls._construct_fields, fields, options)
options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces))
options.graphql_type = generate_objecttype(cls) options.registry.register(cls)
if issubclass(cls, SQLAlchemyObjectType): options.sqlalchemy_fields = get_fields_in_type(
options.registry.register(cls) ObjectType,
cls._construct_fields(options.fields, options)
)
options.fields.update(options.sqlalchemy_fields)
return cls return cls
class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)): class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)):
is_type_of = None @classmethod
def is_type_of(cls, root, context, info):
if isinstance(root, cls):
return True
if not is_mapped(type(root)):
raise Exception((
'Received incompatible instance "{}".'
).format(root))
return type(root) == cls._meta.model
class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, type(Node)):
@staticmethod
def _get_interface_options(meta):
return Options(
meta,
name=None,
description=None,
graphql_type=None,
registry=False
)
@staticmethod
def _create_interface(cls, name, bases, attrs):
cls = super(SQLAlchemyNodeMeta, cls)._create_interface(cls, name, bases, attrs)
if not cls._meta.registry:
cls._meta.registry = get_global_registry()
assert isinstance(cls._meta.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry.'.format(name)
return cls
class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)):
@classmethod @classmethod
def get_node(cls, id, context, info): def get_node(cls, id, context, info):
try: try:
@ -142,16 +116,8 @@ class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)):
except NoResultFound: except NoResultFound:
return None return None
@classmethod def resolve_id(root, args, context, info):
def resolve_id(cls, root, args, context, info): graphene_type = info.parent_type.graphene_type
return root.__mapper__.primary_key_from_instance(root)[0] if is_node(graphene_type):
return root.__mapper__.primary_key_from_instance(root)[0]
@classmethod return getattr(root, graphene_type._meta.id, None)
def resolve_type(cls, type_instance, context, info):
# We get the model from the _meta in the SQLAlchemy class/instance
model = type(type_instance)
graphene_type = cls._meta.registry.get_type_for_model(model)
if graphene_type:
return get_graphql_type(graphene_type)
raise Exception("Type not found for model \"{}\"".format(model))