From f296a2a73f9b5ebbd1598ac9c9aa5bb17d41eb12 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Sun, 14 Aug 2016 15:42:27 -0700 Subject: [PATCH] Updated SQLAlchemy code to work with latest version of graphene. --- .../examples/flask_sqlalchemy/schema.py | 14 +- .../graphene_sqlalchemy/__init__.py | 3 +- .../graphene_sqlalchemy/converter.py | 28 ++-- .../graphene_sqlalchemy/fields.py | 16 +- .../graphene_sqlalchemy/registry.py | 8 +- .../tests/test_converter.py | 53 +++--- .../graphene_sqlalchemy/tests/test_query.py | 17 +- .../graphene_sqlalchemy/tests/test_schema.py | 6 +- .../graphene_sqlalchemy/tests/test_types.py | 39 ++--- .../graphene_sqlalchemy/types.py | 156 +++++++----------- 10 files changed, 161 insertions(+), 179 deletions(-) diff --git a/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py b/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py index b9cf38fb..df967dc4 100644 --- a/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py +++ b/graphene-sqlalchemy/examples/flask_sqlalchemy/schema.py @@ -1,33 +1,35 @@ import graphene from graphene import relay from graphene_sqlalchemy import (SQLAlchemyConnectionField, - SQLAlchemyObjectType, - SQLAlchemyNode) + SQLAlchemyObjectType) from models import Department as DepartmentModel from models import Employee as EmployeeModel from models import Role as RoleModel -class Department(SQLAlchemyNode, SQLAlchemyObjectType): +class Department(SQLAlchemyObjectType): class Meta: model = DepartmentModel + interfaces = (relay.Node, ) -class Employee(SQLAlchemyNode, SQLAlchemyObjectType): +class Employee(SQLAlchemyObjectType): class Meta: model = EmployeeModel + interfaces = (relay.Node, ) -class Role(SQLAlchemyNode, SQLAlchemyObjectType): +class Role(SQLAlchemyObjectType): class Meta: model = RoleModel + interfaces = (relay.Node, ) class Query(graphene.ObjectType): - node = SQLAlchemyNode.Field() + node = relay.Node.Field() all_employees = SQLAlchemyConnectionField(Employee) all_roles = SQLAlchemyConnectionField(Role) role = graphene.Field(Role) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py b/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py index 80017886..bdb21103 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py @@ -1,10 +1,9 @@ from .types import ( SQLAlchemyObjectType, - SQLAlchemyNode ) from .fields import ( SQLAlchemyConnectionField ) -__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyNode', +__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyConnectionField'] diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py index 8a8cd84e..9f4d96bc 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py @@ -3,8 +3,8 @@ from sqlalchemy import types from sqlalchemy.orm import interfaces from sqlalchemy.dialects import postgresql -from graphene import Enum, ID, Boolean, Float, Int, String, List, Field -from graphene.relay import Node +from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic +from graphene.relay import is_node from graphene.types.json import JSONString from .fields import SQLAlchemyConnectionField @@ -18,16 +18,20 @@ except ImportError: def convert_sqlalchemy_relationship(relationship, registry): direction = relationship.direction model = relationship.mapper.entity - _type = registry.get_type_for_model(model) - if not _type: - return None - if direction == interfaces.MANYTOONE: - return Field(_type) - elif (direction == interfaces.ONETOMANY or - direction == interfaces.MANYTOMANY): - if issubclass(_type, Node): - return SQLAlchemyConnectionField(_type) - return List(_type) + + def dynamic_type(): + _type = registry.get_type_for_model(model) + if not _type: + return None + if direction == interfaces.MANYTOONE: + return Field(_type) + elif (direction == interfaces.ONETOMANY or + direction == interfaces.MANYTOMANY): + if is_node(_type): + return SQLAlchemyConnectionField(_type) + return Field(List(_type)) + + return Dynamic(dynamic_type) def convert_sqlalchemy_column(column, registry=None): diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/fields.py b/graphene-sqlalchemy/graphene_sqlalchemy/fields.py index b6402e91..69f154b0 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/fields.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/fields.py @@ -1,3 +1,4 @@ +from functools import partial from sqlalchemy.orm.query import Query from graphene.relay import ConnectionField @@ -9,17 +10,13 @@ class SQLAlchemyConnectionField(ConnectionField): @property def model(self): - return self.connection._meta.node._meta.model - - def get_query(self, context): - return get_query(self.model, context) - - def default_resolver(self, root, args, context, info): - return getattr(root, self.source or self.attname, self.get_query(context)) + return self.type._meta.node._meta.model @staticmethod - def connection_resolver(resolver, connection, root, args, context, info): + def connection_resolver(resolver, connection, model, root, args, context, info): iterable = resolver(root, args, context, info) + if iterable is None: + iterable = get_query(model, context) if isinstance(iterable, Query): _len = iterable.count() else: @@ -33,3 +30,6 @@ class SQLAlchemyConnectionField(ConnectionField): connection_type=connection, edge_type=connection.Edge, ) + + def get_resolver(self, parent_resolver): + return partial(self.connection_resolver, parent_resolver, self.type, self.model) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/registry.py b/graphene-sqlalchemy/graphene_sqlalchemy/registry.py index 5062f1ab..adb56e82 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/registry.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/registry.py @@ -7,10 +7,10 @@ class Registry(object): from .types import SQLAlchemyObjectType assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__) assert cls._meta.registry == self, 'Registry for a Model have to match.' - assert self.get_type_for_model(cls._meta.model) in [None, cls], ( - 'SQLAlchemy model "{}" already associated with ' - 'another type "{}".' - ).format(cls._meta.model, self._registry[cls._meta.model]) + # assert self.get_type_for_model(cls._meta.model) in [None, cls], ( + # 'SQLAlchemy model "{}" already associated with ' + # 'another type "{}".' + # ).format(cls._meta.model, self._registry[cls._meta.model]) self._registry[cls._meta.model] = cls def get_type_for_model(self, model): diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py index 8a6044af..eb23f856 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py @@ -5,11 +5,12 @@ from sqlalchemy_utils.types.choice import ChoiceType from sqlalchemy.dialects import postgresql import graphene +from graphene.relay import Node from graphene.types.json import JSONString from ..converter import (convert_sqlalchemy_column, convert_sqlalchemy_relationship) from ..fields import SQLAlchemyConnectionField -from ..types import SQLAlchemyObjectType, SQLAlchemyNode +from ..types import SQLAlchemyObjectType from ..registry import Registry from .models import Article, Pet, Reporter @@ -19,7 +20,7 @@ def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs): column = Column(sqlalchemy_type, doc='Custom Help Text', **kwargs) graphene_type = convert_sqlalchemy_column(column) assert isinstance(graphene_type, graphene_field) - field = graphene_type.as_field() + field = graphene_type.Field() assert field.description == 'Custom Help Text' return field @@ -101,16 +102,17 @@ def test_should_choice_convert_enum(): Table('translatedmodel', Base.metadata, column) graphene_type = convert_sqlalchemy_column(column) assert issubclass(graphene_type, graphene.Enum) - assert graphene_type._meta.graphql_type.name == 'TRANSLATEDMODEL_LANGUAGE' - assert graphene_type._meta.graphql_type.description == 'Language' + assert graphene_type._meta.name == 'TRANSLATEDMODEL_LANGUAGE' + assert graphene_type._meta.description == 'Language' assert graphene_type._meta.enum.__members__['es'].value == 'Spanish' assert graphene_type._meta.enum.__members__['en'].value == 'English' def test_should_manytomany_convert_connectionorlist(): registry = Registry() - graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, registry) - assert not graphene_type + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry) + assert isinstance(dynamic_field, graphene.Dynamic) + assert not dynamic_field.get_type() def test_should_manytomany_convert_connectionorlist_list(): @@ -118,26 +120,30 @@ def test_should_manytomany_convert_connectionorlist_list(): class Meta: model = Pet - graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) - assert isinstance(graphene_type, graphene.List) - assert graphene_type.of_type == A._meta.graphql_type + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() + assert isinstance(graphene_type, graphene.Field) + assert isinstance(graphene_type.type, graphene.List) + assert graphene_type.type.of_type == A def test_should_manytomany_convert_connectionorlist_connection(): - class A(SQLAlchemyNode, SQLAlchemyObjectType): + class A(SQLAlchemyObjectType): class Meta: model = Pet + interfaces = (Node, ) - graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) - assert isinstance(graphene_type, SQLAlchemyConnectionField) - - + dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + assert isinstance(dynamic_field, graphene.Dynamic) + assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField) def test_should_manytoone_convert_connectionorlist(): registry = Registry() - graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, registry) - assert not graphene_type + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry) + assert isinstance(dynamic_field, graphene.Dynamic) + assert not dynamic_field.get_type() def test_should_manytoone_convert_connectionorlist_list(): @@ -145,19 +151,24 @@ def test_should_manytoone_convert_connectionorlist_list(): class Meta: model = Reporter - graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) - assert graphene_type.type == A._meta.graphql_type + assert graphene_type.type == A def test_should_manytoone_convert_connectionorlist_connection(): - class A(SQLAlchemyNode, SQLAlchemyObjectType): + class A(SQLAlchemyObjectType): class Meta: model = Reporter + interfaces = (Node, ) - graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + assert isinstance(dynamic_field, graphene.Dynamic) + graphene_type = dynamic_field.get_type() assert isinstance(graphene_type, graphene.Field) - assert graphene_type.type == A._meta.graphql_type + assert graphene_type.type == A def test_should_postgresql_uuid_convert(): diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py index c9640c49..a7380cd8 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py @@ -3,8 +3,8 @@ from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session, sessionmaker import graphene -from graphene import relay -from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) +from graphene.relay import Node +from ..types import SQLAlchemyObjectType from ..fields import SQLAlchemyConnectionField from .models import Article, Base, Editor, Reporter @@ -93,10 +93,11 @@ def test_should_query_well(session): def test_should_node(session): setup_fixtures(session) - class ReporterNode(SQLAlchemyNode, SQLAlchemyObjectType): + class ReporterNode(SQLAlchemyObjectType): class Meta: model = Reporter + interfaces = (Node, ) @classmethod def get_node(cls, id, info): @@ -105,17 +106,18 @@ def test_should_node(session): def resolve_articles(self, *args, **kwargs): return [Article(headline='Hi!')] - class ArticleNode(SQLAlchemyNode, SQLAlchemyObjectType): + class ArticleNode(SQLAlchemyObjectType): class Meta: model = Article + interfaces = (Node, ) # @classmethod # def get_node(cls, id, info): # return Article(id=1, headline='Article node') class Query(graphene.ObjectType): - node = SQLAlchemyNode.Field() + node = Node.Field() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleNode) @@ -194,13 +196,14 @@ def test_should_node(session): def test_should_custom_identifier(session): setup_fixtures(session) - class EditorNode(SQLAlchemyNode, SQLAlchemyObjectType): + class EditorNode(SQLAlchemyObjectType): class Meta: model = Editor + interfaces = (Node, ) class Query(graphene.ObjectType): - node = SQLAlchemyNode.Field() + node = Node.Field() all_editors = SQLAlchemyConnectionField(EditorNode) query = ''' diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py index f18f1f94..24bfb7f4 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py @@ -28,7 +28,7 @@ def test_should_map_fields_correctly(): model = Reporter registry = Registry() - assert ReporterType2._meta.get_fields().keys() == ['id', 'firstName', 'lastName', 'email'] + assert ReporterType2._meta.fields.keys() == ['id', 'first_name', 'last_name', 'email', 'pets', 'articles'] def test_should_map_only_few_fields(): @@ -36,5 +36,5 @@ def test_should_map_only_few_fields(): class Meta: model = Reporter - only = ('id', 'email') - assert Reporter2._meta.get_fields().keys() == ['id', 'email'] + only_fields = ('id', 'email') + assert Reporter2._meta.fields.keys() == ['id', 'email'] diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py index 7a387471..527a9fb4 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py @@ -1,10 +1,10 @@ from graphql.type import GraphQLObjectType, GraphQLInterfaceType -from graphql.type.definition import GraphQLFieldDefinition from graphql import GraphQLInt from pytest import raises -from graphene import Schema -from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) +from graphene import Schema, Interface, ObjectType +from graphene.relay import Node, is_node +from ..types import SQLAlchemyObjectType from ..registry import Registry from graphene import Field, Int @@ -14,6 +14,7 @@ from .models import Article, Reporter registry = Registry() + class Character(SQLAlchemyObjectType): '''Character description''' class Meta: @@ -21,22 +22,21 @@ class Character(SQLAlchemyObjectType): registry = registry -class Human(SQLAlchemyNode, SQLAlchemyObjectType): +class Human(SQLAlchemyObjectType): '''Human description''' pub_date = Int() class Meta: model = Article - exclude = ('id', ) + exclude_fields = ('id', ) registry = registry - - - + interfaces = (Node, ) def test_sqlalchemy_interface(): - assert isinstance(SQLAlchemyNode._meta.graphql_type, GraphQLInterfaceType) + assert issubclass(Node, Interface) + assert issubclass(Node, Node) # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) @@ -47,14 +47,13 @@ def test_sqlalchemy_interface(): def test_objecttype_registered(): - object_type = Character._meta.graphql_type - assert isinstance(object_type, GraphQLObjectType) + assert issubclass(Character, ObjectType) assert Character._meta.model == Reporter - assert object_type.get_fields().keys() == ['articles', 'id', 'firstName', 'lastName', 'email'] + assert Character._meta.fields.keys() == ['id', 'first_name', 'last_name', 'email', 'pets', 'articles'] # def test_sqlalchemynode_idfield(): -# idfield = SQLAlchemyNode._meta.fields_map['id'] +# idfield = Node._meta.fields_map['id'] # assert isinstance(idfield, GlobalIDField) @@ -64,14 +63,12 @@ def test_objecttype_registered(): def test_node_replacedfield(): - idfield = Human._meta.graphql_type.get_fields()['pubDate'] - assert isinstance(idfield, GraphQLFieldDefinition) - assert idfield.type == GraphQLInt + idfield = Human._meta.fields['pub_date'] + assert isinstance(idfield, Field) + assert idfield.type == Int def test_object_type(): - object_type = Human._meta.graphql_type - object_type.get_fields() - assert isinstance(object_type, GraphQLObjectType) - assert object_type.get_fields().keys() == ['id', 'pubDate', 'reporter', 'headline', 'reporterId'] - assert SQLAlchemyNode._meta.graphql_type in object_type.get_interfaces() + assert issubclass(Human, ObjectType) + assert Human._meta.fields.keys() == ['id', 'pub_date', 'headline', 'reporter_id', 'reporter'] + assert is_node(Human) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/types.py b/graphene-sqlalchemy/graphene_sqlalchemy/types.py index 186a897d..fe3b21da 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/types.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/types.py @@ -1,138 +1,112 @@ +from collections import OrderedDict import six from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.orm.exc import NoResultFound from graphene import ObjectType -from graphene.relay import Node +from graphene.relay import is_node from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_relationship) from .utils import is_mapped -from functools import partial - - -from graphene import Field, Interface +from graphene.types.objecttype import ObjectTypeMeta from graphene.types.options import Options -from graphene.types.objecttype import attrs_without_fields, get_interfaces - from .registry import Registry, get_global_registry -from .utils import get_query from graphene.utils.is_base_type import is_base_type -from graphene.utils.copy_fields import copy_fields -from graphene.utils.get_graphql_type import get_graphql_type -from graphene.utils.get_fields import get_fields -from graphene.utils.as_field import as_field -from graphene.generators import generate_objecttype +from graphene.types.utils import get_fields_in_type +from .utils import get_query -class SQLAlchemyObjectTypeMeta(type(ObjectType)): - def _construct_fields(cls, fields, options): - only_fields = cls._meta.only - exclude_fields = cls._meta.exclude +class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): + def _construct_fields(cls, all_fields, options): + only_fields = cls._meta.only_fields + exclude_fields = cls._meta.exclude_fields inspected_model = sqlalchemyinspect(cls._meta.model) - # Get all the columns for the relationships on the model - for relationship in inspected_model.relationships: - is_not_in_only = only_fields and relationship.key not in only_fields - is_already_created = relationship.key in fields - is_excluded = relationship.key in exclude_fields or is_already_created - if is_not_in_only or is_excluded: - # We skip this field if we specify only_fields and is not - # in there. Or when we excldue this field in exclude_fields - continue - converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) - if not converted_relationship: - continue - name = relationship.key - fields[name] = as_field(converted_relationship) + fields = OrderedDict() for name, column in inspected_model.columns.items(): is_not_in_only = only_fields and name not in only_fields - is_already_created = name in fields + is_already_created = name in all_fields is_excluded = name in exclude_fields or is_already_created if is_not_in_only or is_excluded: # We skip this field if we specify only_fields and is not # in there. Or when we excldue this field in exclude_fields continue converted_column = convert_sqlalchemy_column(column, options.registry) - if not converted_column: - continue - fields[name] = as_field(converted_column) + fields[name] = converted_column - fields = copy_fields(Field, fields, parent=cls) + # Get all the columns for the relationships on the model + for relationship in inspected_model.relationships: + is_not_in_only = only_fields and relationship.key not in only_fields + is_already_created = relationship.key in all_fields + is_excluded = relationship.key in exclude_fields or is_already_created + if is_not_in_only or is_excluded: + # We skip this field if we specify only_fields and is not + # in there. Or when we excldue this field in exclude_fields + continue + converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) + name = relationship.key + fields[name] = converted_relationship return fields @staticmethod - def _create_objecttype(cls, name, bases, attrs): - # super_new = super(SQLAlchemyObjectTypeMeta, cls).__new__ - super_new = type.__new__ - + def __new__(cls, name, bases, attrs): # Also ensure initialization is only performed for subclasses of Model # (excluding Model class itself). if not is_base_type(bases, SQLAlchemyObjectTypeMeta): - return super_new(cls, name, bases, attrs) + return type.__new__(cls, name, bases, attrs) options = Options( attrs.pop('Meta', None), - name=None, - description=None, + name=name, + description=attrs.pop('__doc__', None), model=None, - fields=(), - exclude=(), - only=(), + fields=None, + only_fields=(), + exclude_fields=(), + id='id', interfaces=(), registry=None ) if not options.registry: options.registry = get_global_registry() - assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry) - assert is_mapped(options.model), 'You need to pass a valid SQLAlchemy Model in {}.Meta, received "{}".'.format(name, options.model) + assert isinstance(options.registry, Registry), ( + 'The attribute registry in {}.Meta needs to be an' + ' instance of Registry, received "{}".' + ).format(name, options.registry) + assert is_mapped(options.model), ( + 'You need to pass a valid SQLAlchemy Model in ' + '{}.Meta, received "{}".' + ).format(name, options.model) - interfaces = tuple(options.interfaces) - fields = get_fields(ObjectType, attrs, bases, interfaces) - attrs = attrs_without_fields(attrs, fields) - cls = super_new(cls, name, bases, dict(attrs, _meta=options)) - base_interfaces = tuple(b for b in bases if issubclass(b, Interface)) - options.get_fields = partial(cls._construct_fields, fields, options) - options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces)) + cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options)) - options.graphql_type = generate_objecttype(cls) + options.registry.register(cls) - if issubclass(cls, SQLAlchemyObjectType): - options.registry.register(cls) + options.sqlalchemy_fields = get_fields_in_type( + ObjectType, + cls._construct_fields(options.fields, options) + ) + options.fields.update(options.sqlalchemy_fields) return cls class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)): - is_type_of = None + @classmethod + def is_type_of(cls, root, context, info): + if isinstance(root, cls): + return True + if not is_mapped(type(root)): + raise Exception(( + 'Received incompatible instance "{}".' + ).format(root)) + return type(root) == cls._meta.model - -class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, type(Node)): - - @staticmethod - def _get_interface_options(meta): - return Options( - meta, - name=None, - description=None, - graphql_type=None, - registry=False - ) - - @staticmethod - def _create_interface(cls, name, bases, attrs): - cls = super(SQLAlchemyNodeMeta, cls)._create_interface(cls, name, bases, attrs) - if not cls._meta.registry: - cls._meta.registry = get_global_registry() - assert isinstance(cls._meta.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry.'.format(name) - return cls - - -class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)): @classmethod def get_node(cls, id, context, info): try: @@ -142,16 +116,8 @@ class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)): except NoResultFound: return None - @classmethod - def resolve_id(cls, root, args, context, info): - return root.__mapper__.primary_key_from_instance(root)[0] - - @classmethod - def resolve_type(cls, type_instance, context, info): - # We get the model from the _meta in the SQLAlchemy class/instance - model = type(type_instance) - graphene_type = cls._meta.registry.get_type_for_model(model) - if graphene_type: - return get_graphql_type(graphene_type) - - raise Exception("Type not found for model \"{}\"".format(model)) + def resolve_id(root, args, context, info): + graphene_type = info.parent_type.graphene_type + if is_node(graphene_type): + return root.__mapper__.primary_key_from_instance(root)[0] + return getattr(root, graphene_type._meta.id, None)