From af4c63512cc8c3d50162c9079d5c1cb8ebaaca72 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Fri, 22 Jul 2016 20:18:23 -0700 Subject: [PATCH] First working version of graphene-sqlalchemy --- graphene-django/graphene_django/converter.py | 1 - .../graphene_sqlalchemy/__init__.py | 4 +- .../graphene_sqlalchemy/converter.py | 37 ++-- .../graphene_sqlalchemy/fields.py | 82 +++----- .../graphene_sqlalchemy/options.py | 24 --- .../graphene_sqlalchemy/registry.py | 28 +++ .../tests/test_converter.py | 76 +++++-- .../graphene_sqlalchemy/tests/test_query.py | 29 +-- .../graphene_sqlalchemy/tests/test_schema.py | 23 +- .../graphene_sqlalchemy/tests/test_types.py | 89 +++----- .../graphene_sqlalchemy/tests/test_utils.py | 9 +- .../graphene_sqlalchemy/types.py | 199 ++++++++++-------- .../graphene_sqlalchemy/utils.py | 37 +--- graphene/relay/connection.py | 7 +- graphene/relay/node.py | 6 +- 15 files changed, 318 insertions(+), 333 deletions(-) delete mode 100644 graphene-sqlalchemy/graphene_sqlalchemy/options.py create mode 100644 graphene-sqlalchemy/graphene_sqlalchemy/registry.py diff --git a/graphene-django/graphene_django/converter.py b/graphene-django/graphene_django/converter.py index 2078639b..625897e6 100644 --- a/graphene-django/graphene_django/converter.py +++ b/graphene-django/graphene_django/converter.py @@ -4,7 +4,6 @@ from django.utils.encoding import force_text from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull from graphene.types.json import JSONString from graphene.types.datetime import DateTime -from graphene.types.json import JSONString from graphene.utils.str_converters import to_const from graphene.relay import Node diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py b/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py index 10bf8f5e..80017886 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/__init__.py @@ -1,8 +1,8 @@ -from graphene.contrib.sqlalchemy.types import ( +from .types import ( SQLAlchemyObjectType, SQLAlchemyNode ) -from graphene.contrib.sqlalchemy.fields import ( +from .fields import ( SQLAlchemyConnectionField ) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py index 1bf7e37f..8a8cd84e 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/converter.py @@ -3,9 +3,10 @@ from sqlalchemy import types from sqlalchemy.orm import interfaces from sqlalchemy.dialects import postgresql -from graphene import Enum, ID, Boolean, Float, Int, String, List +from graphene import Enum, ID, Boolean, Float, Int, String, List, Field +from graphene.relay import Node from graphene.types.json import JSONString -from .fields import ConnectionOrListField, SQLAlchemyModelField +from .fields import SQLAlchemyConnectionField try: from sqlalchemy_utils.types.choice import ChoiceType @@ -14,23 +15,27 @@ except ImportError: pass -def convert_sqlalchemy_relationship(relationship): +def convert_sqlalchemy_relationship(relationship, registry): direction = relationship.direction model = relationship.mapper.entity - model_field = SQLAlchemyModelField(model, description=relationship.doc) + _type = registry.get_type_for_model(model) + if not _type: + return None if direction == interfaces.MANYTOONE: - return model_field + return Field(_type) elif (direction == interfaces.ONETOMANY or direction == interfaces.MANYTOMANY): - return ConnectionOrListField(model_field) + if issubclass(_type, Node): + return SQLAlchemyConnectionField(_type) + return List(_type) -def convert_sqlalchemy_column(column): - return convert_sqlalchemy_type(getattr(column, 'type', None), column) +def convert_sqlalchemy_column(column, registry=None): + return convert_sqlalchemy_type(getattr(column, 'type', None), column, registry) @singledispatch -def convert_sqlalchemy_type(type, column): +def convert_sqlalchemy_type(type, column, registry=None): raise Exception( "Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__)) @@ -45,14 +50,14 @@ def convert_sqlalchemy_type(type, column): @convert_sqlalchemy_type.register(types.Enum) @convert_sqlalchemy_type.register(postgresql.ENUM) @convert_sqlalchemy_type.register(postgresql.UUID) -def convert_column_to_string(type, column): +def convert_column_to_string(type, column, registry=None): return String(description=column.doc) @convert_sqlalchemy_type.register(types.SmallInteger) @convert_sqlalchemy_type.register(types.BigInteger) @convert_sqlalchemy_type.register(types.Integer) -def convert_column_to_int_or_id(type, column): +def convert_column_to_int_or_id(type, column, registry=None): if column.primary_key: return ID(description=column.doc) else: @@ -60,24 +65,24 @@ def convert_column_to_int_or_id(type, column): @convert_sqlalchemy_type.register(types.Boolean) -def convert_column_to_boolean(type, column): +def convert_column_to_boolean(type, column, registry=None): return Boolean(description=column.doc) @convert_sqlalchemy_type.register(types.Float) @convert_sqlalchemy_type.register(types.Numeric) -def convert_column_to_float(type, column): +def convert_column_to_float(type, column, registry=None): return Float(description=column.doc) @convert_sqlalchemy_type.register(ChoiceType) -def convert_column_to_enum(type, column): +def convert_column_to_enum(type, column, registry=None): name = '{}_{}'.format(column.table.name, column.name).upper() return Enum(name, type.choices, description=column.doc) @convert_sqlalchemy_type.register(postgresql.ARRAY) -def convert_postgres_array_to_list(type, column): +def convert_postgres_array_to_list(type, column, registry=None): graphene_type = convert_sqlalchemy_type(column.type.item_type, column) return List(graphene_type, description=column.doc) @@ -85,5 +90,5 @@ def convert_postgres_array_to_list(type, column): @convert_sqlalchemy_type.register(postgresql.HSTORE) @convert_sqlalchemy_type.register(postgresql.JSON) @convert_sqlalchemy_type.register(postgresql.JSONB) -def convert_json_to_string(type, column): +def convert_json_to_string(type, column, registry=None): return JSONString(description=column.doc) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/fields.py b/graphene-sqlalchemy/graphene_sqlalchemy/fields.py index 598cd341..b6402e91 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/fields.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/fields.py @@ -1,69 +1,35 @@ -from ...core.exceptions import SkipField -from ...core.fields import Field -from ...core.types.base import FieldType -from ...core.types.definitions import List -from ...relay import ConnectionField -from ...relay.utils import is_node -from .utils import get_query, get_type_for_model, maybe_query +from sqlalchemy.orm.query import Query - -class DefaultQuery(object): - pass +from graphene.relay import ConnectionField +from graphql_relay.connection.arrayconnection import connection_from_list_slice +from .utils import get_query class SQLAlchemyConnectionField(ConnectionField): - def __init__(self, *args, **kwargs): - kwargs['default'] = kwargs.pop('default', lambda: DefaultQuery) - return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs) - @property def model(self): - return self.type._meta.model + return self.connection._meta.node._meta.model - def from_list(self, connection_type, resolved, args, context, info): - if resolved is DefaultQuery: - resolved = get_query(self.model, info) - query = maybe_query(resolved) - return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, context, info) + def get_query(self, context): + return get_query(self.model, context) + def default_resolver(self, root, args, context, info): + return getattr(root, self.source or self.attname, self.get_query(context)) -class ConnectionOrListField(Field): - - def internal_type(self, schema): - model_field = self.type - field_object_type = model_field.get_object_type(schema) - if not field_object_type: - raise SkipField() - if is_node(field_object_type): - field = SQLAlchemyConnectionField(field_object_type) + @staticmethod + def connection_resolver(resolver, connection, root, args, context, info): + iterable = resolver(root, args, context, info) + if isinstance(iterable, Query): + _len = iterable.count() else: - field = Field(List(field_object_type)) - field.contribute_to_class(self.object_type, self.attname) - return schema.T(field) - - -class SQLAlchemyModelField(FieldType): - - def __init__(self, model, *args, **kwargs): - self.model = model - super(SQLAlchemyModelField, self).__init__(*args, **kwargs) - - def internal_type(self, schema): - _type = self.get_object_type(schema) - if not _type and self.parent._meta.only_fields: - raise Exception( - "Table %r is not accessible by the schema. " - "You can either register the type manually " - "using @schema.register. " - "Or disable the field in %s" % ( - self.model, - self.parent, - ) - ) - if not _type: - raise SkipField() - return schema.T(_type) - - def get_object_type(self, schema): - return get_type_for_model(schema, self.model) + _len = len(iterable) + return connection_from_list_slice( + iterable, + args, + slice_start=0, + list_length=_len, + list_slice_length=_len, + connection_type=connection, + edge_type=connection.Edge, + ) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/options.py b/graphene-sqlalchemy/graphene_sqlalchemy/options.py deleted file mode 100644 index 44886287..00000000 --- a/graphene-sqlalchemy/graphene_sqlalchemy/options.py +++ /dev/null @@ -1,24 +0,0 @@ -from ...core.classtypes.objecttype import ObjectTypeOptions -from ...relay.types import Node -from ...relay.utils import is_node - -VALID_ATTRS = ('model', 'only_fields', 'exclude_fields', 'identifier') - - -class SQLAlchemyOptions(ObjectTypeOptions): - - def __init__(self, *args, **kwargs): - super(SQLAlchemyOptions, self).__init__(*args, **kwargs) - self.model = None - self.identifier = "id" - self.valid_attrs += VALID_ATTRS - self.only_fields = None - self.exclude_fields = [] - self.filter_fields = None - self.filter_order_by = None - - def contribute_to_class(self, cls, name): - super(SQLAlchemyOptions, self).contribute_to_class(cls, name) - if is_node(cls): - self.exclude_fields = list(self.exclude_fields) + ['id'] - self.interfaces.append(Node) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/registry.py b/graphene-sqlalchemy/graphene_sqlalchemy/registry.py new file mode 100644 index 00000000..56492965 --- /dev/null +++ b/graphene-sqlalchemy/graphene_sqlalchemy/registry.py @@ -0,0 +1,28 @@ +class Registry(object): + def __init__(self): + self._registry = {} + self._registry_models = {} + + def register(self, cls): + from .types import SQLAlchemyObjectType + assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__) + assert cls._meta.registry == self, 'Registry for a Model have to match.' + self._registry[cls._meta.model] = cls + + def get_type_for_model(self, model): + return self._registry.get(model) + + +registry = None + + +def get_global_registry(): + global registry + if not registry: + registry = Registry() + return registry + + +def reset_global_registry(): + global registry + registry = None diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py index 521911ee..8a6044af 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_converter.py @@ -5,11 +5,12 @@ from sqlalchemy_utils.types.choice import ChoiceType from sqlalchemy.dialects import postgresql import graphene -from graphene.core.types.custom_scalars import JSONString -from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column, - convert_sqlalchemy_relationship) -from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField, - SQLAlchemyModelField) +from graphene.types.json import JSONString +from ..converter import (convert_sqlalchemy_column, + convert_sqlalchemy_relationship) +from ..fields import SQLAlchemyConnectionField +from ..types import SQLAlchemyObjectType, SQLAlchemyNode +from ..registry import Registry from .models import Article, Pet, Reporter @@ -100,30 +101,63 @@ def test_should_choice_convert_enum(): Table('translatedmodel', Base.metadata, column) graphene_type = convert_sqlalchemy_column(column) assert issubclass(graphene_type, graphene.Enum) - assert graphene_type._meta.type_name == 'TRANSLATEDMODEL_LANGUAGE' - assert graphene_type._meta.description == 'Language' - assert graphene_type.__enum__.__members__['es'].value == 'Spanish' - assert graphene_type.__enum__.__members__['en'].value == 'English' + assert graphene_type._meta.graphql_type.name == 'TRANSLATEDMODEL_LANGUAGE' + assert graphene_type._meta.graphql_type.description == 'Language' + assert graphene_type._meta.enum.__members__['es'].value == 'Spanish' + assert graphene_type._meta.enum.__members__['en'].value == 'English' def test_should_manytomany_convert_connectionorlist(): - graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property) - assert isinstance(graphene_type, ConnectionOrListField) - assert isinstance(graphene_type.type, SQLAlchemyModelField) - assert graphene_type.type.model == Pet + registry = Registry() + graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, registry) + assert not graphene_type + + +def test_should_manytomany_convert_connectionorlist_list(): + class A(SQLAlchemyObjectType): + class Meta: + model = Pet + + graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + assert isinstance(graphene_type, graphene.List) + assert graphene_type.of_type == A._meta.graphql_type + + +def test_should_manytomany_convert_connectionorlist_connection(): + class A(SQLAlchemyNode, SQLAlchemyObjectType): + class Meta: + model = Pet + + graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) + assert isinstance(graphene_type, SQLAlchemyConnectionField) + + def test_should_manytoone_convert_connectionorlist(): - field = convert_sqlalchemy_relationship(Article.reporter.property) - assert isinstance(field, SQLAlchemyModelField) - assert field.model == Reporter + registry = Registry() + graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, registry) + assert not graphene_type -def test_should_onetomany_convert_model(): - graphene_type = convert_sqlalchemy_relationship(Reporter.articles.property) - assert isinstance(graphene_type, ConnectionOrListField) - assert isinstance(graphene_type.type, SQLAlchemyModelField) - assert graphene_type.type.model == Article +def test_should_manytoone_convert_connectionorlist_list(): + class A(SQLAlchemyObjectType): + class Meta: + model = Reporter + + graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + assert isinstance(graphene_type, graphene.Field) + assert graphene_type.type == A._meta.graphql_type + + +def test_should_manytoone_convert_connectionorlist_connection(): + class A(SQLAlchemyNode, SQLAlchemyObjectType): + class Meta: + model = Reporter + + graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) + assert isinstance(graphene_type, graphene.Field) + assert graphene_type.type == A._meta.graphql_type def test_should_postgresql_uuid_convert(): diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py index da8f8e11..c9640c49 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_query.py @@ -4,8 +4,8 @@ from sqlalchemy.orm import scoped_session, sessionmaker import graphene from graphene import relay -from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField, - SQLAlchemyNode, SQLAlchemyObjectType) +from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) +from ..fields import SQLAlchemyConnectionField from .models import Article, Base, Editor, Reporter @@ -52,7 +52,7 @@ def test_should_query_well(session): class Query(graphene.ObjectType): reporter = graphene.Field(ReporterType) - reporters = ReporterType.List() + reporters = graphene.List(ReporterType) def resolve_reporter(self, *args, **kwargs): return session.query(Reporter).first() @@ -93,7 +93,7 @@ def test_should_query_well(session): def test_should_node(session): setup_fixtures(session) - class ReporterNode(SQLAlchemyNode): + class ReporterNode(SQLAlchemyNode, SQLAlchemyObjectType): class Meta: model = Reporter @@ -105,7 +105,7 @@ def test_should_node(session): def resolve_articles(self, *args, **kwargs): return [Article(headline='Hi!')] - class ArticleNode(SQLAlchemyNode): + class ArticleNode(SQLAlchemyNode, SQLAlchemyObjectType): class Meta: model = Article @@ -115,7 +115,7 @@ def test_should_node(session): # return Article(id=1, headline='Article node') class Query(graphene.ObjectType): - node = relay.NodeField() + node = SQLAlchemyNode.Field() reporter = graphene.Field(ReporterNode) article = graphene.Field(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleNode) @@ -185,8 +185,8 @@ def test_should_node(session): 'headline': 'Hi!' } } - schema = graphene.Schema(query=Query, session=session) - result = schema.execute(query) + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors assert result.data == expected @@ -194,14 +194,13 @@ def test_should_node(session): def test_should_custom_identifier(session): setup_fixtures(session) - class EditorNode(SQLAlchemyNode): + class EditorNode(SQLAlchemyNode, SQLAlchemyObjectType): class Meta: model = Editor - identifier = "editor_id" class Query(graphene.ObjectType): - node = relay.NodeField(EditorNode) + node = SQLAlchemyNode.Field() all_editors = SQLAlchemyConnectionField(EditorNode) query = ''' @@ -215,7 +214,9 @@ def test_should_custom_identifier(session): } }, node(id: "RWRpdG9yTm9kZTox") { - name + ...on EditorNode { + name + } } } ''' @@ -233,7 +234,7 @@ def test_should_custom_identifier(session): } } - schema = graphene.Schema(query=Query, session=session) - result = schema.execute(query) + schema = graphene.Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors assert result.data == expected diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py index 090b2e18..f18f1f94 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_schema.py @@ -1,25 +1,24 @@ from py.test import raises -from graphene.contrib.sqlalchemy import SQLAlchemyObjectType -from tests.utils import assert_equal_lists +from ..types import SQLAlchemyObjectType from .models import Reporter +from ..registry import Registry def test_should_raise_if_no_model(): with raises(Exception) as excinfo: class Character1(SQLAlchemyObjectType): pass - assert 'model in the Meta' in str(excinfo.value) + assert 'valid SQLAlchemy Model' in str(excinfo.value) def test_should_raise_if_model_is_invalid(): with raises(Exception) as excinfo: class Character2(SQLAlchemyObjectType): - class Meta: model = 1 - assert 'not a SQLAlchemy model' in str(excinfo.value) + assert 'valid SQLAlchemy Model' in str(excinfo.value) def test_should_map_fields_correctly(): @@ -27,10 +26,9 @@ def test_should_map_fields_correctly(): class Meta: model = Reporter - assert_equal_lists( - ReporterType2._meta.fields_map.keys(), - ['articles', 'first_name', 'last_name', 'email', 'pets', 'id'] - ) + registry = Registry() + + assert ReporterType2._meta.get_fields().keys() == ['id', 'firstName', 'lastName', 'email'] def test_should_map_only_few_fields(): @@ -38,8 +36,5 @@ def test_should_map_only_few_fields(): class Meta: model = Reporter - only_fields = ('id', 'email') - assert_equal_lists( - Reporter2._meta.fields_map.keys(), - ['id', 'email'] - ) + only = ('id', 'email') + assert Reporter2._meta.get_fields().keys() == ['id', 'email'] diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py index 378411ae..7a387471 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_types.py @@ -1,38 +1,42 @@ -from graphql.type import GraphQLObjectType +from graphql.type import GraphQLObjectType, GraphQLInterfaceType +from graphql.type.definition import GraphQLFieldDefinition +from graphql import GraphQLInt from pytest import raises from graphene import Schema -from graphene.contrib.sqlalchemy.types import (SQLAlchemyNode, - SQLAlchemyObjectType) -from graphene.core.fields import Field -from graphene.core.types.scalars import Int -from graphene.relay.fields import GlobalIDField -from tests.utils import assert_equal_lists +from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) +from ..registry import Registry + +from graphene import Field, Int +# from tests.utils import assert_equal_lists from .models import Article, Reporter -schema = Schema() - +registry = Registry() class Character(SQLAlchemyObjectType): '''Character description''' class Meta: model = Reporter + registry = registry -@schema.register -class Human(SQLAlchemyNode): +class Human(SQLAlchemyNode, SQLAlchemyObjectType): '''Human description''' pub_date = Int() class Meta: model = Article - exclude_fields = ('id', ) + exclude = ('id', ) + registry = registry + + + def test_sqlalchemy_interface(): - assert SQLAlchemyNode._meta.interface is True + assert isinstance(SQLAlchemyNode._meta.graphql_type, GraphQLInterfaceType) # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) @@ -43,60 +47,31 @@ def test_sqlalchemy_interface(): def test_objecttype_registered(): - object_type = schema.T(Character) + object_type = Character._meta.graphql_type assert isinstance(object_type, GraphQLObjectType) assert Character._meta.model == Reporter - assert_equal_lists( - object_type.get_fields().keys(), - ['articles', 'firstName', 'lastName', 'email', 'id'] - ) + assert object_type.get_fields().keys() == ['articles', 'id', 'firstName', 'lastName', 'email'] -def test_sqlalchemynode_idfield(): - idfield = SQLAlchemyNode._meta.fields_map['id'] - assert isinstance(idfield, GlobalIDField) +# def test_sqlalchemynode_idfield(): +# idfield = SQLAlchemyNode._meta.fields_map['id'] +# assert isinstance(idfield, GlobalIDField) -def test_node_idfield(): - idfield = Human._meta.fields_map['id'] - assert isinstance(idfield, GlobalIDField) +# def test_node_idfield(): +# idfield = Human._meta.fields_map['id'] +# assert isinstance(idfield, GlobalIDField) def test_node_replacedfield(): - idfield = Human._meta.fields_map['pub_date'] - assert isinstance(idfield, Field) - assert schema.T(idfield).type == schema.T(Int()) - - -def test_interface_objecttype_init_none(): - h = Human() - assert h._root is None - - -def test_interface_objecttype_init_good(): - instance = Article() - h = Human(instance) - assert h._root == instance - - -def test_interface_objecttype_init_unexpected(): - with raises(AssertionError) as excinfo: - Human(object()) - assert str(excinfo.value) == "Human received a non-compatible instance (object) when expecting Article" + idfield = Human._meta.graphql_type.get_fields()['pubDate'] + assert isinstance(idfield, GraphQLFieldDefinition) + assert idfield.type == GraphQLInt def test_object_type(): - object_type = schema.T(Human) - Human._meta.fields_map - assert Human._meta.interface is False + object_type = Human._meta.graphql_type + object_type.get_fields() assert isinstance(object_type, GraphQLObjectType) - assert_equal_lists( - object_type.get_fields().keys(), - ['headline', 'id', 'reporter', 'reporterId', 'pubDate'] - ) - assert schema.T(SQLAlchemyNode) in object_type.get_interfaces() - - -def test_node_notinterface(): - assert Human._meta.interface is False - assert SQLAlchemyNode in Human._meta.interfaces + assert object_type.get_fields().keys() == ['id', 'pubDate', 'reporter', 'headline', 'reporterId'] + assert SQLAlchemyNode._meta.graphql_type in object_type.get_interfaces() diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_utils.py b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_utils.py index 2925f016..484b9f6a 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_utils.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/tests/test_utils.py @@ -5,13 +5,12 @@ from ..utils import get_session def test_get_session(): session = 'My SQLAlchemy session' - schema = Schema(session=session) class Query(ObjectType): x = String() - def resolve_x(self, args, info): - return get_session(info) + def resolve_x(self, args, context, info): + return get_session(context) query = ''' query ReporterQuery { @@ -19,7 +18,7 @@ def test_get_session(): } ''' - schema = Schema(query=Query, session=session) - result = schema.execute(query) + schema = Schema(query=Query) + result = schema.execute(query, context_value={'session': session}) assert not result.errors assert result.data['x'] == session diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/types.py b/graphene-sqlalchemy/graphene_sqlalchemy/types.py index 20202ab7..e6c8ef85 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/types.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/types.py @@ -1,125 +1,158 @@ -import inspect - import six from sqlalchemy.inspection import inspect as sqlalchemyinspect from sqlalchemy.orm.exc import NoResultFound -from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta -from ...relay.types import Connection, Node, NodeMeta +from graphene import ObjectType +from graphene.relay import Node from .converter import (convert_sqlalchemy_column, convert_sqlalchemy_relationship) -from .options import SQLAlchemyOptions -from .utils import get_query, is_mapped +from .utils import is_mapped + +from functools import partial -class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): - options_class = SQLAlchemyOptions +from graphene import Field, Interface +from graphene.types.options import Options +from graphene.types.objecttype import attrs_without_fields, get_interfaces - def construct_fields(cls): - only_fields = cls._meta.only_fields - exclude_fields = cls._meta.exclude_fields - already_created_fields = {f.attname for f in cls._meta.local_fields} +from .registry import Registry, get_global_registry +from .utils import get_query +from graphene.utils.is_base_type import is_base_type +from graphene.utils.copy_fields import copy_fields +from graphene.utils.get_graphql_type import get_graphql_type +from graphene.utils.get_fields import get_fields +from graphene.utils.as_field import as_field +from graphene.generators import generate_objecttype + + +class SQLAlchemyObjectTypeMeta(type(ObjectType)): + def _construct_fields(cls, fields, options): + only_fields = cls._meta.only + exclude_fields = cls._meta.exclude inspected_model = sqlalchemyinspect(cls._meta.model) # Get all the columns for the relationships on the model for relationship in inspected_model.relationships: is_not_in_only = only_fields and relationship.key not in only_fields - is_already_created = relationship.key in already_created_fields + is_already_created = relationship.key in fields is_excluded = relationship.key in exclude_fields or is_already_created if is_not_in_only or is_excluded: # We skip this field if we specify only_fields and is not # in there. Or when we excldue this field in exclude_fields continue - converted_relationship = convert_sqlalchemy_relationship(relationship) - cls.add_to_class(relationship.key, converted_relationship) + converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) + if not converted_relationship: + continue + name = relationship.key + fields[name] = as_field(converted_relationship) for name, column in inspected_model.columns.items(): is_not_in_only = only_fields and name not in only_fields - is_already_created = name in already_created_fields + is_already_created = name in fields is_excluded = name in exclude_fields or is_already_created if is_not_in_only or is_excluded: # We skip this field if we specify only_fields and is not # in there. Or when we excldue this field in exclude_fields continue - converted_column = convert_sqlalchemy_column(column) - cls.add_to_class(name, converted_column) + converted_column = convert_sqlalchemy_column(column, options.registry) + if not converted_column: + continue + fields[name] = as_field(converted_column) - def construct(cls, *args, **kwargs): - cls = super(SQLAlchemyObjectTypeMeta, cls).construct(*args, **kwargs) - if not cls._meta.abstract: - if not cls._meta.model: - raise Exception( - 'SQLAlchemy ObjectType %s must have a model in the Meta class attr' % - cls) - elif not inspect.isclass(cls._meta.model) or not is_mapped(cls._meta.model): - raise Exception('Provided model in %s is not a SQLAlchemy model' % cls) + fields = copy_fields(Field, fields, parent=cls) + + return fields + + @staticmethod + def _create_objecttype(cls, name, bases, attrs): + # super_new = super(SQLAlchemyObjectTypeMeta, cls).__new__ + super_new = type.__new__ + + # Also ensure initialization is only performed for subclasses of Model + # (excluding Model class itself). + if not is_base_type(bases, SQLAlchemyObjectTypeMeta): + return super_new(cls, name, bases, attrs) + + options = Options( + attrs.pop('Meta', None), + name=None, + description=None, + model=None, + fields=(), + exclude=(), + only=(), + interfaces=(), + registry=None + ) + + if not options.registry: + options.registry = get_global_registry() + assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry) + assert is_mapped(options.model), 'You need to pass a valid SQLAlchemy Model in {}.Meta, received "{}".'.format(name, options.model) + + interfaces = tuple(options.interfaces) + fields = get_fields(ObjectType, attrs, bases, interfaces) + attrs = attrs_without_fields(attrs, fields) + cls = super_new(cls, name, bases, dict(attrs, _meta=options)) + + base_interfaces = tuple(b for b in bases if issubclass(b, Interface)) + options.get_fields = partial(cls._construct_fields, fields, options) + options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces)) + + options.graphql_type = generate_objecttype(cls) + + if issubclass(cls, SQLAlchemyObjectType): + options.registry.register(cls) - cls.construct_fields() return cls -class InstanceObjectType(ObjectType): - - class Meta: - abstract = True - - def __init__(self, _root=None): - super(InstanceObjectType, self).__init__(_root=_root) - assert not self._root or isinstance(self._root, self._meta.model), ( - '{} received a non-compatible instance ({}) ' - 'when expecting {}'.format( - self.__class__.__name__, - self._root.__class__.__name__, - self._meta.model.__name__ - )) - - @property - def instance(self): - return self._root - - @instance.setter - def instance(self, value): - self._root = value +class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)): + is_type_of = None -class SQLAlchemyObjectType(six.with_metaclass( - SQLAlchemyObjectTypeMeta, InstanceObjectType)): +class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, type(Node)): - class Meta: - abstract = True + @staticmethod + def _get_interface_options(meta): + return Options( + meta, + name=None, + description=None, + model=None, + graphql_type=None, + registry=False + ) + + @staticmethod + def _create_interface(cls, name, bases, attrs): + cls = super(SQLAlchemyNodeMeta, cls)._create_interface(cls, name, bases, attrs) + if not cls._meta.registry: + cls._meta.registry = get_global_registry() + assert isinstance(cls._meta.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry.'.format(name) + return cls -class SQLAlchemyConnection(Connection): - pass - - -class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, NodeMeta): - pass - - -class NodeInstance(Node, InstanceObjectType): - - class Meta: - abstract = True - - -class SQLAlchemyNode(six.with_metaclass( - SQLAlchemyNodeMeta, NodeInstance)): - - class Meta: - abstract = True - - def to_global_id(self): - id_ = getattr(self.instance, self._meta.identifier) - return self.global_id(id_) - +class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)): @classmethod - def get_node(cls, id, info=None): + def get_node(cls, id, context, info): try: model = cls._meta.model - identifier = cls._meta.identifier - query = get_query(model, info) - instance = query.filter(getattr(model, identifier) == id).one() - return cls(instance) + query = get_query(model, context) + return query.get(id) except NoResultFound: return None + + @classmethod + def resolve_id(cls, root, args, context, info): + return root.__mapper__.primary_key_from_instance(root)[0] + + @classmethod + def resolve_type(cls, type_instance, context, info): + # We get the model from the _meta in the SQLAlchemy class/instance + model = type(type_instance) + graphene_type = cls._meta.registry.get_type_for_model(model) + if graphene_type: + return get_graphql_type(graphene_type) + + raise Exception("Type not found for model \"{}\"".format(model)) diff --git a/graphene-sqlalchemy/graphene_sqlalchemy/utils.py b/graphene-sqlalchemy/graphene_sqlalchemy/utils.py index 246a9d86..02183081 100644 --- a/graphene-sqlalchemy/graphene_sqlalchemy/utils.py +++ b/graphene-sqlalchemy/graphene_sqlalchemy/utils.py @@ -1,28 +1,14 @@ from sqlalchemy.ext.declarative.api import DeclarativeMeta -from sqlalchemy.orm.query import Query - -from graphene.utils import LazyList -def get_type_for_model(schema, model): - schema = schema - types = schema.types.values() - for _type in types: - type_model = hasattr(_type, '_meta') and getattr( - _type._meta, 'model', None) - if model == type_model: - return _type +def get_session(context): + return context.get('session') -def get_session(info): - schema = info.schema.graphene_schema - return schema.options.get('session') - - -def get_query(model, info): +def get_query(model, context): query = getattr(model, 'query', None) if not query: - session = get_session(info) + session = get_session(context) if not session: raise Exception('A query in the model Base or a session in the schema is required for querying.\n' 'Read more http://graphene-python.org/docs/sqlalchemy/tips/#querying') @@ -30,20 +16,5 @@ def get_query(model, info): return query -class WrappedQuery(LazyList): - - def __len__(self): - # Dont calculate the length using len(query), as this will - # evaluate the whole queryset and return it's length. - # Use .count() instead - return self._origin.count() - - -def maybe_query(value): - if isinstance(value, Query): - return WrappedQuery(value) - return value - - def is_mapped(obj): return isinstance(obj, DeclarativeMeta) diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index 3377a41c..802ee5f0 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -96,11 +96,10 @@ class IterableConnectionField(Field): @property def connection(self): from .node import Node - graphql_type = super(IterableConnectionField, self).type - if issubclass(graphql_type.graphene_type, Node): - connection_type = graphql_type.graphene_type.get_default_connection() + if issubclass(self._type, Node): + connection_type = self._type.get_default_connection() else: - connection_type = graphql_type.graphene_type + connection_type = self._type assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self)) return connection_type diff --git a/graphene/relay/node.py b/graphene/relay/node.py index 4089b4d5..99121232 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -84,9 +84,13 @@ class Node(six.with_metaclass(NodeMeta, Interface)): # return to_global_id(type, id) # raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__)) + @classmethod + def resolve_id(cls, root, args, context, info): + return getattr(root, 'id', None) + @classmethod def id_resolver(cls, root, args, context, info): - return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None)) + return cls.to_global_id(info.parent_type.name, cls.resolve_id(root, args, context, info)) @classmethod def get_node_from_global_id(cls, global_id, context, info):