mirror of
				https://github.com/graphql-python/graphene.git
				synced 2025-10-31 16:07:27 +03:00 
			
		
		
		
	First working version of graphene-sqlalchemy
This commit is contained in:
		
							parent
							
								
									79d7636ab6
								
							
						
					
					
						commit
						af4c63512c
					
				|  | @ -4,7 +4,6 @@ from django.utils.encoding import force_text | ||||||
| from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull | from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull | ||||||
| from graphene.types.json import JSONString | from graphene.types.json import JSONString | ||||||
| from graphene.types.datetime import DateTime | from graphene.types.datetime import DateTime | ||||||
| from graphene.types.json import JSONString |  | ||||||
| from graphene.utils.str_converters import to_const | from graphene.utils.str_converters import to_const | ||||||
| from graphene.relay import Node | from graphene.relay import Node | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,8 +1,8 @@ | ||||||
| from graphene.contrib.sqlalchemy.types import ( | from .types import ( | ||||||
|     SQLAlchemyObjectType, |     SQLAlchemyObjectType, | ||||||
|     SQLAlchemyNode |     SQLAlchemyNode | ||||||
| ) | ) | ||||||
| from graphene.contrib.sqlalchemy.fields import ( | from .fields import ( | ||||||
|     SQLAlchemyConnectionField |     SQLAlchemyConnectionField | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -3,9 +3,10 @@ from sqlalchemy import types | ||||||
| from sqlalchemy.orm import interfaces | from sqlalchemy.orm import interfaces | ||||||
| from sqlalchemy.dialects import postgresql | from sqlalchemy.dialects import postgresql | ||||||
| 
 | 
 | ||||||
| from graphene import Enum, ID, Boolean, Float, Int, String, List | from graphene import Enum, ID, Boolean, Float, Int, String, List, Field | ||||||
|  | from graphene.relay import Node | ||||||
| from graphene.types.json import JSONString | from graphene.types.json import JSONString | ||||||
| from .fields import ConnectionOrListField, SQLAlchemyModelField | from .fields import SQLAlchemyConnectionField | ||||||
| 
 | 
 | ||||||
| try: | try: | ||||||
|     from sqlalchemy_utils.types.choice import ChoiceType |     from sqlalchemy_utils.types.choice import ChoiceType | ||||||
|  | @ -14,23 +15,27 @@ except ImportError: | ||||||
|         pass |         pass | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def convert_sqlalchemy_relationship(relationship): | def convert_sqlalchemy_relationship(relationship, registry): | ||||||
|     direction = relationship.direction |     direction = relationship.direction | ||||||
|     model = relationship.mapper.entity |     model = relationship.mapper.entity | ||||||
|     model_field = SQLAlchemyModelField(model, description=relationship.doc) |     _type = registry.get_type_for_model(model) | ||||||
|  |     if not _type: | ||||||
|  |         return None | ||||||
|     if direction == interfaces.MANYTOONE: |     if direction == interfaces.MANYTOONE: | ||||||
|         return model_field |         return Field(_type) | ||||||
|     elif (direction == interfaces.ONETOMANY or |     elif (direction == interfaces.ONETOMANY or | ||||||
|           direction == interfaces.MANYTOMANY): |           direction == interfaces.MANYTOMANY): | ||||||
|         return ConnectionOrListField(model_field) |         if issubclass(_type, Node): | ||||||
|  |             return SQLAlchemyConnectionField(_type) | ||||||
|  |         return List(_type) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def convert_sqlalchemy_column(column): | def convert_sqlalchemy_column(column, registry=None): | ||||||
|     return convert_sqlalchemy_type(getattr(column, 'type', None), column) |     return convert_sqlalchemy_type(getattr(column, 'type', None), column, registry) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @singledispatch | @singledispatch | ||||||
| def convert_sqlalchemy_type(type, column): | def convert_sqlalchemy_type(type, column, registry=None): | ||||||
|     raise Exception( |     raise Exception( | ||||||
|         "Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__)) |         "Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__)) | ||||||
| 
 | 
 | ||||||
|  | @ -45,14 +50,14 @@ def convert_sqlalchemy_type(type, column): | ||||||
| @convert_sqlalchemy_type.register(types.Enum) | @convert_sqlalchemy_type.register(types.Enum) | ||||||
| @convert_sqlalchemy_type.register(postgresql.ENUM) | @convert_sqlalchemy_type.register(postgresql.ENUM) | ||||||
| @convert_sqlalchemy_type.register(postgresql.UUID) | @convert_sqlalchemy_type.register(postgresql.UUID) | ||||||
| def convert_column_to_string(type, column): | def convert_column_to_string(type, column, registry=None): | ||||||
|     return String(description=column.doc) |     return String(description=column.doc) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @convert_sqlalchemy_type.register(types.SmallInteger) | @convert_sqlalchemy_type.register(types.SmallInteger) | ||||||
| @convert_sqlalchemy_type.register(types.BigInteger) | @convert_sqlalchemy_type.register(types.BigInteger) | ||||||
| @convert_sqlalchemy_type.register(types.Integer) | @convert_sqlalchemy_type.register(types.Integer) | ||||||
| def convert_column_to_int_or_id(type, column): | def convert_column_to_int_or_id(type, column, registry=None): | ||||||
|     if column.primary_key: |     if column.primary_key: | ||||||
|         return ID(description=column.doc) |         return ID(description=column.doc) | ||||||
|     else: |     else: | ||||||
|  | @ -60,24 +65,24 @@ def convert_column_to_int_or_id(type, column): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @convert_sqlalchemy_type.register(types.Boolean) | @convert_sqlalchemy_type.register(types.Boolean) | ||||||
| def convert_column_to_boolean(type, column): | def convert_column_to_boolean(type, column, registry=None): | ||||||
|     return Boolean(description=column.doc) |     return Boolean(description=column.doc) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @convert_sqlalchemy_type.register(types.Float) | @convert_sqlalchemy_type.register(types.Float) | ||||||
| @convert_sqlalchemy_type.register(types.Numeric) | @convert_sqlalchemy_type.register(types.Numeric) | ||||||
| def convert_column_to_float(type, column): | def convert_column_to_float(type, column, registry=None): | ||||||
|     return Float(description=column.doc) |     return Float(description=column.doc) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @convert_sqlalchemy_type.register(ChoiceType) | @convert_sqlalchemy_type.register(ChoiceType) | ||||||
| def convert_column_to_enum(type, column): | def convert_column_to_enum(type, column, registry=None): | ||||||
|     name = '{}_{}'.format(column.table.name, column.name).upper() |     name = '{}_{}'.format(column.table.name, column.name).upper() | ||||||
|     return Enum(name, type.choices, description=column.doc) |     return Enum(name, type.choices, description=column.doc) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @convert_sqlalchemy_type.register(postgresql.ARRAY) | @convert_sqlalchemy_type.register(postgresql.ARRAY) | ||||||
| def convert_postgres_array_to_list(type, column): | def convert_postgres_array_to_list(type, column, registry=None): | ||||||
|     graphene_type = convert_sqlalchemy_type(column.type.item_type, column) |     graphene_type = convert_sqlalchemy_type(column.type.item_type, column) | ||||||
|     return List(graphene_type, description=column.doc) |     return List(graphene_type, description=column.doc) | ||||||
| 
 | 
 | ||||||
|  | @ -85,5 +90,5 @@ def convert_postgres_array_to_list(type, column): | ||||||
| @convert_sqlalchemy_type.register(postgresql.HSTORE) | @convert_sqlalchemy_type.register(postgresql.HSTORE) | ||||||
| @convert_sqlalchemy_type.register(postgresql.JSON) | @convert_sqlalchemy_type.register(postgresql.JSON) | ||||||
| @convert_sqlalchemy_type.register(postgresql.JSONB) | @convert_sqlalchemy_type.register(postgresql.JSONB) | ||||||
| def convert_json_to_string(type, column): | def convert_json_to_string(type, column, registry=None): | ||||||
|     return JSONString(description=column.doc) |     return JSONString(description=column.doc) | ||||||
|  |  | ||||||
|  | @ -1,69 +1,35 @@ | ||||||
| from ...core.exceptions import SkipField | from sqlalchemy.orm.query import Query | ||||||
| from ...core.fields import Field |  | ||||||
| from ...core.types.base import FieldType |  | ||||||
| from ...core.types.definitions import List |  | ||||||
| from ...relay import ConnectionField |  | ||||||
| from ...relay.utils import is_node |  | ||||||
| from .utils import get_query, get_type_for_model, maybe_query |  | ||||||
| 
 | 
 | ||||||
| 
 | from graphene.relay import ConnectionField | ||||||
| class DefaultQuery(object): | from graphql_relay.connection.arrayconnection import connection_from_list_slice | ||||||
|     pass | from .utils import get_query | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SQLAlchemyConnectionField(ConnectionField): | class SQLAlchemyConnectionField(ConnectionField): | ||||||
| 
 | 
 | ||||||
|     def __init__(self, *args, **kwargs): |  | ||||||
|         kwargs['default'] = kwargs.pop('default', lambda: DefaultQuery) |  | ||||||
|         return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs) |  | ||||||
| 
 |  | ||||||
|     @property |     @property | ||||||
|     def model(self): |     def model(self): | ||||||
|         return self.type._meta.model |         return self.connection._meta.node._meta.model | ||||||
| 
 | 
 | ||||||
|     def from_list(self, connection_type, resolved, args, context,  info): |     def get_query(self, context): | ||||||
|         if resolved is DefaultQuery: |         return get_query(self.model, context) | ||||||
|             resolved = get_query(self.model, info) |  | ||||||
|         query = maybe_query(resolved) |  | ||||||
|         return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, context, info) |  | ||||||
| 
 | 
 | ||||||
|  |     def default_resolver(self, root, args, context, info): | ||||||
|  |         return getattr(root, self.source or self.attname, self.get_query(context)) | ||||||
| 
 | 
 | ||||||
| class ConnectionOrListField(Field): |     @staticmethod | ||||||
| 
 |     def connection_resolver(resolver, connection, root, args, context, info): | ||||||
|     def internal_type(self, schema): |         iterable = resolver(root, args, context, info) | ||||||
|         model_field = self.type |         if isinstance(iterable, Query): | ||||||
|         field_object_type = model_field.get_object_type(schema) |             _len = iterable.count() | ||||||
|         if not field_object_type: |  | ||||||
|             raise SkipField() |  | ||||||
|         if is_node(field_object_type): |  | ||||||
|             field = SQLAlchemyConnectionField(field_object_type) |  | ||||||
|         else: |         else: | ||||||
|             field = Field(List(field_object_type)) |             _len = len(iterable) | ||||||
|         field.contribute_to_class(self.object_type, self.attname) |         return connection_from_list_slice( | ||||||
|         return schema.T(field) |             iterable, | ||||||
| 
 |             args, | ||||||
| 
 |             slice_start=0, | ||||||
| class SQLAlchemyModelField(FieldType): |             list_length=_len, | ||||||
| 
 |             list_slice_length=_len, | ||||||
|     def __init__(self, model, *args, **kwargs): |             connection_type=connection, | ||||||
|         self.model = model |             edge_type=connection.Edge, | ||||||
|         super(SQLAlchemyModelField, self).__init__(*args, **kwargs) |  | ||||||
| 
 |  | ||||||
|     def internal_type(self, schema): |  | ||||||
|         _type = self.get_object_type(schema) |  | ||||||
|         if not _type and self.parent._meta.only_fields: |  | ||||||
|             raise Exception( |  | ||||||
|                 "Table %r is not accessible by the schema. " |  | ||||||
|                 "You can either register the type manually " |  | ||||||
|                 "using @schema.register. " |  | ||||||
|                 "Or disable the field in %s" % ( |  | ||||||
|                     self.model, |  | ||||||
|                     self.parent, |  | ||||||
|         ) |         ) | ||||||
|             ) |  | ||||||
|         if not _type: |  | ||||||
|             raise SkipField() |  | ||||||
|         return schema.T(_type) |  | ||||||
| 
 |  | ||||||
|     def get_object_type(self, schema): |  | ||||||
|         return get_type_for_model(schema, self.model) |  | ||||||
|  |  | ||||||
|  | @ -1,24 +0,0 @@ | ||||||
| from ...core.classtypes.objecttype import ObjectTypeOptions |  | ||||||
| from ...relay.types import Node |  | ||||||
| from ...relay.utils import is_node |  | ||||||
| 
 |  | ||||||
| VALID_ATTRS = ('model', 'only_fields', 'exclude_fields', 'identifier') |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class SQLAlchemyOptions(ObjectTypeOptions): |  | ||||||
| 
 |  | ||||||
|     def __init__(self, *args, **kwargs): |  | ||||||
|         super(SQLAlchemyOptions, self).__init__(*args, **kwargs) |  | ||||||
|         self.model = None |  | ||||||
|         self.identifier = "id" |  | ||||||
|         self.valid_attrs += VALID_ATTRS |  | ||||||
|         self.only_fields = None |  | ||||||
|         self.exclude_fields = [] |  | ||||||
|         self.filter_fields = None |  | ||||||
|         self.filter_order_by = None |  | ||||||
| 
 |  | ||||||
|     def contribute_to_class(self, cls, name): |  | ||||||
|         super(SQLAlchemyOptions, self).contribute_to_class(cls, name) |  | ||||||
|         if is_node(cls): |  | ||||||
|             self.exclude_fields = list(self.exclude_fields) + ['id'] |  | ||||||
|             self.interfaces.append(Node) |  | ||||||
							
								
								
									
										28
									
								
								graphene-sqlalchemy/graphene_sqlalchemy/registry.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								graphene-sqlalchemy/graphene_sqlalchemy/registry.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,28 @@ | ||||||
|  | class Registry(object): | ||||||
|  |     def __init__(self): | ||||||
|  |         self._registry = {} | ||||||
|  |         self._registry_models = {} | ||||||
|  | 
 | ||||||
|  |     def register(self, cls): | ||||||
|  |         from .types import SQLAlchemyObjectType | ||||||
|  |         assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__) | ||||||
|  |         assert cls._meta.registry == self, 'Registry for a Model have to match.' | ||||||
|  |         self._registry[cls._meta.model] = cls | ||||||
|  | 
 | ||||||
|  |     def get_type_for_model(self, model): | ||||||
|  |         return self._registry.get(model) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | registry = None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def get_global_registry(): | ||||||
|  |     global registry | ||||||
|  |     if not registry: | ||||||
|  |         registry = Registry() | ||||||
|  |     return registry | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def reset_global_registry(): | ||||||
|  |     global registry | ||||||
|  |     registry = None | ||||||
|  | @ -5,11 +5,12 @@ from sqlalchemy_utils.types.choice import ChoiceType | ||||||
| from sqlalchemy.dialects import postgresql | from sqlalchemy.dialects import postgresql | ||||||
| 
 | 
 | ||||||
| import graphene | import graphene | ||||||
| from graphene.core.types.custom_scalars import JSONString | from graphene.types.json import JSONString | ||||||
| from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column, | from ..converter import (convert_sqlalchemy_column, | ||||||
|                          convert_sqlalchemy_relationship) |                          convert_sqlalchemy_relationship) | ||||||
| from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField, | from ..fields import SQLAlchemyConnectionField | ||||||
|                                                 SQLAlchemyModelField) | from ..types import SQLAlchemyObjectType, SQLAlchemyNode | ||||||
|  | from ..registry import Registry | ||||||
| 
 | 
 | ||||||
| from .models import Article, Pet, Reporter | from .models import Article, Pet, Reporter | ||||||
| 
 | 
 | ||||||
|  | @ -100,30 +101,63 @@ def test_should_choice_convert_enum(): | ||||||
|     Table('translatedmodel', Base.metadata, column) |     Table('translatedmodel', Base.metadata, column) | ||||||
|     graphene_type = convert_sqlalchemy_column(column) |     graphene_type = convert_sqlalchemy_column(column) | ||||||
|     assert issubclass(graphene_type, graphene.Enum) |     assert issubclass(graphene_type, graphene.Enum) | ||||||
|     assert graphene_type._meta.type_name == 'TRANSLATEDMODEL_LANGUAGE' |     assert graphene_type._meta.graphql_type.name == 'TRANSLATEDMODEL_LANGUAGE' | ||||||
|     assert graphene_type._meta.description == 'Language' |     assert graphene_type._meta.graphql_type.description == 'Language' | ||||||
|     assert graphene_type.__enum__.__members__['es'].value == 'Spanish' |     assert graphene_type._meta.enum.__members__['es'].value == 'Spanish' | ||||||
|     assert graphene_type.__enum__.__members__['en'].value == 'English' |     assert graphene_type._meta.enum.__members__['en'].value == 'English' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_manytomany_convert_connectionorlist(): | def test_should_manytomany_convert_connectionorlist(): | ||||||
|     graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property) |     registry = Registry() | ||||||
|     assert isinstance(graphene_type, ConnectionOrListField) |     graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, registry) | ||||||
|     assert isinstance(graphene_type.type, SQLAlchemyModelField) |     assert not graphene_type | ||||||
|     assert graphene_type.type.model == Pet | 
 | ||||||
|  | 
 | ||||||
|  | def test_should_manytomany_convert_connectionorlist_list(): | ||||||
|  |     class A(SQLAlchemyObjectType): | ||||||
|  |         class Meta: | ||||||
|  |             model = Pet | ||||||
|  | 
 | ||||||
|  |     graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) | ||||||
|  |     assert isinstance(graphene_type, graphene.List) | ||||||
|  |     assert graphene_type.of_type == A._meta.graphql_type | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_should_manytomany_convert_connectionorlist_connection(): | ||||||
|  |     class A(SQLAlchemyNode, SQLAlchemyObjectType): | ||||||
|  |         class Meta: | ||||||
|  |             model = Pet | ||||||
|  | 
 | ||||||
|  |     graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) | ||||||
|  |     assert isinstance(graphene_type, SQLAlchemyConnectionField) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_manytoone_convert_connectionorlist(): | def test_should_manytoone_convert_connectionorlist(): | ||||||
|     field = convert_sqlalchemy_relationship(Article.reporter.property) |     registry = Registry() | ||||||
|     assert isinstance(field, SQLAlchemyModelField) |     graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, registry) | ||||||
|     assert field.model == Reporter |     assert not graphene_type | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_onetomany_convert_model(): | def test_should_manytoone_convert_connectionorlist_list(): | ||||||
|     graphene_type = convert_sqlalchemy_relationship(Reporter.articles.property) |     class A(SQLAlchemyObjectType): | ||||||
|     assert isinstance(graphene_type, ConnectionOrListField) |         class Meta: | ||||||
|     assert isinstance(graphene_type.type, SQLAlchemyModelField) |             model = Reporter | ||||||
|     assert graphene_type.type.model == Article | 
 | ||||||
|  |     graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) | ||||||
|  |     assert isinstance(graphene_type, graphene.Field) | ||||||
|  |     assert graphene_type.type == A._meta.graphql_type | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def test_should_manytoone_convert_connectionorlist_connection(): | ||||||
|  |     class A(SQLAlchemyNode, SQLAlchemyObjectType): | ||||||
|  |         class Meta: | ||||||
|  |             model = Reporter | ||||||
|  | 
 | ||||||
|  |     graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) | ||||||
|  |     assert isinstance(graphene_type, graphene.Field) | ||||||
|  |     assert graphene_type.type == A._meta.graphql_type | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_postgresql_uuid_convert(): | def test_should_postgresql_uuid_convert(): | ||||||
|  |  | ||||||
|  | @ -4,8 +4,8 @@ from sqlalchemy.orm import scoped_session, sessionmaker | ||||||
| 
 | 
 | ||||||
| import graphene | import graphene | ||||||
| from graphene import relay | from graphene import relay | ||||||
| from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField, | from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) | ||||||
|                                          SQLAlchemyNode, SQLAlchemyObjectType) | from ..fields import SQLAlchemyConnectionField | ||||||
| 
 | 
 | ||||||
| from .models import Article, Base, Editor, Reporter | from .models import Article, Base, Editor, Reporter | ||||||
| 
 | 
 | ||||||
|  | @ -52,7 +52,7 @@ def test_should_query_well(session): | ||||||
| 
 | 
 | ||||||
|     class Query(graphene.ObjectType): |     class Query(graphene.ObjectType): | ||||||
|         reporter = graphene.Field(ReporterType) |         reporter = graphene.Field(ReporterType) | ||||||
|         reporters = ReporterType.List() |         reporters = graphene.List(ReporterType) | ||||||
| 
 | 
 | ||||||
|         def resolve_reporter(self, *args, **kwargs): |         def resolve_reporter(self, *args, **kwargs): | ||||||
|             return session.query(Reporter).first() |             return session.query(Reporter).first() | ||||||
|  | @ -93,7 +93,7 @@ def test_should_query_well(session): | ||||||
| def test_should_node(session): | def test_should_node(session): | ||||||
|     setup_fixtures(session) |     setup_fixtures(session) | ||||||
| 
 | 
 | ||||||
|     class ReporterNode(SQLAlchemyNode): |     class ReporterNode(SQLAlchemyNode, SQLAlchemyObjectType): | ||||||
| 
 | 
 | ||||||
|         class Meta: |         class Meta: | ||||||
|             model = Reporter |             model = Reporter | ||||||
|  | @ -105,7 +105,7 @@ def test_should_node(session): | ||||||
|         def resolve_articles(self, *args, **kwargs): |         def resolve_articles(self, *args, **kwargs): | ||||||
|             return [Article(headline='Hi!')] |             return [Article(headline='Hi!')] | ||||||
| 
 | 
 | ||||||
|     class ArticleNode(SQLAlchemyNode): |     class ArticleNode(SQLAlchemyNode, SQLAlchemyObjectType): | ||||||
| 
 | 
 | ||||||
|         class Meta: |         class Meta: | ||||||
|             model = Article |             model = Article | ||||||
|  | @ -115,7 +115,7 @@ def test_should_node(session): | ||||||
|         #     return Article(id=1, headline='Article node') |         #     return Article(id=1, headline='Article node') | ||||||
| 
 | 
 | ||||||
|     class Query(graphene.ObjectType): |     class Query(graphene.ObjectType): | ||||||
|         node = relay.NodeField() |         node = SQLAlchemyNode.Field() | ||||||
|         reporter = graphene.Field(ReporterNode) |         reporter = graphene.Field(ReporterNode) | ||||||
|         article = graphene.Field(ArticleNode) |         article = graphene.Field(ArticleNode) | ||||||
|         all_articles = SQLAlchemyConnectionField(ArticleNode) |         all_articles = SQLAlchemyConnectionField(ArticleNode) | ||||||
|  | @ -185,8 +185,8 @@ def test_should_node(session): | ||||||
|             'headline': 'Hi!' |             'headline': 'Hi!' | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|     schema = graphene.Schema(query=Query, session=session) |     schema = graphene.Schema(query=Query) | ||||||
|     result = schema.execute(query) |     result = schema.execute(query, context_value={'session': session}) | ||||||
|     assert not result.errors |     assert not result.errors | ||||||
|     assert result.data == expected |     assert result.data == expected | ||||||
| 
 | 
 | ||||||
|  | @ -194,14 +194,13 @@ def test_should_node(session): | ||||||
| def test_should_custom_identifier(session): | def test_should_custom_identifier(session): | ||||||
|     setup_fixtures(session) |     setup_fixtures(session) | ||||||
| 
 | 
 | ||||||
|     class EditorNode(SQLAlchemyNode): |     class EditorNode(SQLAlchemyNode, SQLAlchemyObjectType): | ||||||
| 
 | 
 | ||||||
|         class Meta: |         class Meta: | ||||||
|             model = Editor |             model = Editor | ||||||
|             identifier = "editor_id" |  | ||||||
| 
 | 
 | ||||||
|     class Query(graphene.ObjectType): |     class Query(graphene.ObjectType): | ||||||
|         node = relay.NodeField(EditorNode) |         node = SQLAlchemyNode.Field() | ||||||
|         all_editors = SQLAlchemyConnectionField(EditorNode) |         all_editors = SQLAlchemyConnectionField(EditorNode) | ||||||
| 
 | 
 | ||||||
|     query = ''' |     query = ''' | ||||||
|  | @ -215,9 +214,11 @@ def test_should_custom_identifier(session): | ||||||
|             } |             } | ||||||
|           }, |           }, | ||||||
|           node(id: "RWRpdG9yTm9kZTox") { |           node(id: "RWRpdG9yTm9kZTox") { | ||||||
|  |             ...on EditorNode { | ||||||
|               name |               name | ||||||
|             } |             } | ||||||
|           } |           } | ||||||
|  |         } | ||||||
|     ''' |     ''' | ||||||
|     expected = { |     expected = { | ||||||
|         'allEditors': { |         'allEditors': { | ||||||
|  | @ -233,7 +234,7 @@ def test_should_custom_identifier(session): | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     schema = graphene.Schema(query=Query, session=session) |     schema = graphene.Schema(query=Query) | ||||||
|     result = schema.execute(query) |     result = schema.execute(query, context_value={'session': session}) | ||||||
|     assert not result.errors |     assert not result.errors | ||||||
|     assert result.data == expected |     assert result.data == expected | ||||||
|  |  | ||||||
|  | @ -1,25 +1,24 @@ | ||||||
| from py.test import raises | from py.test import raises | ||||||
| 
 | 
 | ||||||
| from graphene.contrib.sqlalchemy import SQLAlchemyObjectType | from ..types import SQLAlchemyObjectType | ||||||
| from tests.utils import assert_equal_lists |  | ||||||
| 
 | 
 | ||||||
| from .models import Reporter | from .models import Reporter | ||||||
|  | from ..registry import Registry | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_raise_if_no_model(): | def test_should_raise_if_no_model(): | ||||||
|     with raises(Exception) as excinfo: |     with raises(Exception) as excinfo: | ||||||
|         class Character1(SQLAlchemyObjectType): |         class Character1(SQLAlchemyObjectType): | ||||||
|             pass |             pass | ||||||
|     assert 'model in the Meta' in str(excinfo.value) |     assert 'valid SQLAlchemy Model' in str(excinfo.value) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_raise_if_model_is_invalid(): | def test_should_raise_if_model_is_invalid(): | ||||||
|     with raises(Exception) as excinfo: |     with raises(Exception) as excinfo: | ||||||
|         class Character2(SQLAlchemyObjectType): |         class Character2(SQLAlchemyObjectType): | ||||||
| 
 |  | ||||||
|             class Meta: |             class Meta: | ||||||
|                 model = 1 |                 model = 1 | ||||||
|     assert 'not a SQLAlchemy model' in str(excinfo.value) |     assert 'valid SQLAlchemy Model' in str(excinfo.value) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_map_fields_correctly(): | def test_should_map_fields_correctly(): | ||||||
|  | @ -27,10 +26,9 @@ def test_should_map_fields_correctly(): | ||||||
| 
 | 
 | ||||||
|         class Meta: |         class Meta: | ||||||
|             model = Reporter |             model = Reporter | ||||||
|     assert_equal_lists( |             registry = Registry() | ||||||
|         ReporterType2._meta.fields_map.keys(), | 
 | ||||||
|         ['articles', 'first_name', 'last_name', 'email', 'pets', 'id'] |     assert ReporterType2._meta.get_fields().keys() == ['id', 'firstName', 'lastName', 'email'] | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_should_map_only_few_fields(): | def test_should_map_only_few_fields(): | ||||||
|  | @ -38,8 +36,5 @@ def test_should_map_only_few_fields(): | ||||||
| 
 | 
 | ||||||
|         class Meta: |         class Meta: | ||||||
|             model = Reporter |             model = Reporter | ||||||
|             only_fields = ('id', 'email') |             only = ('id', 'email') | ||||||
|     assert_equal_lists( |     assert Reporter2._meta.get_fields().keys() == ['id', 'email'] | ||||||
|         Reporter2._meta.fields_map.keys(), |  | ||||||
|         ['id', 'email'] |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|  | @ -1,38 +1,42 @@ | ||||||
| from graphql.type import GraphQLObjectType | from graphql.type import GraphQLObjectType, GraphQLInterfaceType | ||||||
|  | from graphql.type.definition import GraphQLFieldDefinition | ||||||
|  | from graphql import GraphQLInt | ||||||
| from pytest import raises | from pytest import raises | ||||||
| 
 | 
 | ||||||
| from graphene import Schema | from graphene import Schema | ||||||
| from graphene.contrib.sqlalchemy.types import (SQLAlchemyNode, | from ..types import (SQLAlchemyNode, SQLAlchemyObjectType) | ||||||
|                                                SQLAlchemyObjectType) | from ..registry import Registry | ||||||
| from graphene.core.fields import Field | 
 | ||||||
| from graphene.core.types.scalars import Int | from graphene import Field, Int | ||||||
| from graphene.relay.fields import GlobalIDField | # from tests.utils import assert_equal_lists | ||||||
| from tests.utils import assert_equal_lists |  | ||||||
| 
 | 
 | ||||||
| from .models import Article, Reporter | from .models import Article, Reporter | ||||||
| 
 | 
 | ||||||
| schema = Schema() | registry = Registry() | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| class Character(SQLAlchemyObjectType): | class Character(SQLAlchemyObjectType): | ||||||
|     '''Character description''' |     '''Character description''' | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = Reporter |         model = Reporter | ||||||
|  |         registry = registry | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @schema.register | class Human(SQLAlchemyNode, SQLAlchemyObjectType): | ||||||
| class Human(SQLAlchemyNode): |  | ||||||
|     '''Human description''' |     '''Human description''' | ||||||
| 
 | 
 | ||||||
|     pub_date = Int() |     pub_date = Int() | ||||||
| 
 | 
 | ||||||
|     class Meta: |     class Meta: | ||||||
|         model = Article |         model = Article | ||||||
|         exclude_fields = ('id', ) |         exclude = ('id', ) | ||||||
|  |         registry = registry | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_sqlalchemy_interface(): | def test_sqlalchemy_interface(): | ||||||
|     assert SQLAlchemyNode._meta.interface is True |     assert isinstance(SQLAlchemyNode._meta.graphql_type, GraphQLInterfaceType) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) | # @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1)) | ||||||
|  | @ -43,60 +47,31 @@ def test_sqlalchemy_interface(): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_objecttype_registered(): | def test_objecttype_registered(): | ||||||
|     object_type = schema.T(Character) |     object_type = Character._meta.graphql_type | ||||||
|     assert isinstance(object_type, GraphQLObjectType) |     assert isinstance(object_type, GraphQLObjectType) | ||||||
|     assert Character._meta.model == Reporter |     assert Character._meta.model == Reporter | ||||||
|     assert_equal_lists( |     assert object_type.get_fields().keys() == ['articles', 'id', 'firstName', 'lastName', 'email'] | ||||||
|         object_type.get_fields().keys(), |  | ||||||
|         ['articles', 'firstName', 'lastName', 'email', 'id'] |  | ||||||
|     ) |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_sqlalchemynode_idfield(): | # def test_sqlalchemynode_idfield(): | ||||||
|     idfield = SQLAlchemyNode._meta.fields_map['id'] | #     idfield = SQLAlchemyNode._meta.fields_map['id'] | ||||||
|     assert isinstance(idfield, GlobalIDField) | #     assert isinstance(idfield, GlobalIDField) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_node_idfield(): | # def test_node_idfield(): | ||||||
|     idfield = Human._meta.fields_map['id'] | #     idfield = Human._meta.fields_map['id'] | ||||||
|     assert isinstance(idfield, GlobalIDField) | #     assert isinstance(idfield, GlobalIDField) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_node_replacedfield(): | def test_node_replacedfield(): | ||||||
|     idfield = Human._meta.fields_map['pub_date'] |     idfield = Human._meta.graphql_type.get_fields()['pubDate'] | ||||||
|     assert isinstance(idfield, Field) |     assert isinstance(idfield, GraphQLFieldDefinition) | ||||||
|     assert schema.T(idfield).type == schema.T(Int()) |     assert idfield.type == GraphQLInt | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def test_interface_objecttype_init_none(): |  | ||||||
|     h = Human() |  | ||||||
|     assert h._root is None |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def test_interface_objecttype_init_good(): |  | ||||||
|     instance = Article() |  | ||||||
|     h = Human(instance) |  | ||||||
|     assert h._root == instance |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def test_interface_objecttype_init_unexpected(): |  | ||||||
|     with raises(AssertionError) as excinfo: |  | ||||||
|         Human(object()) |  | ||||||
|     assert str(excinfo.value) == "Human received a non-compatible instance (object) when expecting Article" |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_object_type(): | def test_object_type(): | ||||||
|     object_type = schema.T(Human) |     object_type = Human._meta.graphql_type | ||||||
|     Human._meta.fields_map |     object_type.get_fields() | ||||||
|     assert Human._meta.interface is False |  | ||||||
|     assert isinstance(object_type, GraphQLObjectType) |     assert isinstance(object_type, GraphQLObjectType) | ||||||
|     assert_equal_lists( |     assert object_type.get_fields().keys() == ['id', 'pubDate', 'reporter', 'headline', 'reporterId'] | ||||||
|         object_type.get_fields().keys(), |     assert SQLAlchemyNode._meta.graphql_type in object_type.get_interfaces() | ||||||
|         ['headline', 'id', 'reporter', 'reporterId', 'pubDate'] |  | ||||||
|     ) |  | ||||||
|     assert schema.T(SQLAlchemyNode) in object_type.get_interfaces() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def test_node_notinterface(): |  | ||||||
|     assert Human._meta.interface is False |  | ||||||
|     assert SQLAlchemyNode in Human._meta.interfaces |  | ||||||
|  |  | ||||||
|  | @ -5,13 +5,12 @@ from ..utils import get_session | ||||||
| 
 | 
 | ||||||
| def test_get_session(): | def test_get_session(): | ||||||
|     session = 'My SQLAlchemy session' |     session = 'My SQLAlchemy session' | ||||||
|     schema = Schema(session=session) |  | ||||||
| 
 | 
 | ||||||
|     class Query(ObjectType): |     class Query(ObjectType): | ||||||
|         x = String() |         x = String() | ||||||
| 
 | 
 | ||||||
|         def resolve_x(self, args, info): |         def resolve_x(self, args, context, info): | ||||||
|             return get_session(info) |             return get_session(context) | ||||||
| 
 | 
 | ||||||
|     query = ''' |     query = ''' | ||||||
|         query ReporterQuery { |         query ReporterQuery { | ||||||
|  | @ -19,7 +18,7 @@ def test_get_session(): | ||||||
|         } |         } | ||||||
|     ''' |     ''' | ||||||
| 
 | 
 | ||||||
|     schema = Schema(query=Query, session=session) |     schema = Schema(query=Query) | ||||||
|     result = schema.execute(query) |     result = schema.execute(query, context_value={'session': session}) | ||||||
|     assert not result.errors |     assert not result.errors | ||||||
|     assert result.data['x'] == session |     assert result.data['x'] == session | ||||||
|  |  | ||||||
|  | @ -1,125 +1,158 @@ | ||||||
| import inspect |  | ||||||
| 
 |  | ||||||
| import six | import six | ||||||
| from sqlalchemy.inspection import inspect as sqlalchemyinspect | from sqlalchemy.inspection import inspect as sqlalchemyinspect | ||||||
| from sqlalchemy.orm.exc import NoResultFound | from sqlalchemy.orm.exc import NoResultFound | ||||||
| 
 | 
 | ||||||
| from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta | from graphene import ObjectType | ||||||
| from ...relay.types import Connection, Node, NodeMeta | from graphene.relay import Node | ||||||
| from .converter import (convert_sqlalchemy_column, | from .converter import (convert_sqlalchemy_column, | ||||||
|                         convert_sqlalchemy_relationship) |                         convert_sqlalchemy_relationship) | ||||||
| from .options import SQLAlchemyOptions | from .utils import is_mapped | ||||||
| from .utils import get_query, is_mapped | 
 | ||||||
|  | from functools import partial | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SQLAlchemyObjectTypeMeta(ObjectTypeMeta): | from graphene import Field, Interface | ||||||
|     options_class = SQLAlchemyOptions | from graphene.types.options import Options | ||||||
|  | from graphene.types.objecttype import attrs_without_fields, get_interfaces | ||||||
| 
 | 
 | ||||||
|     def construct_fields(cls): | from .registry import Registry, get_global_registry | ||||||
|         only_fields = cls._meta.only_fields | from .utils import get_query | ||||||
|         exclude_fields = cls._meta.exclude_fields | from graphene.utils.is_base_type import is_base_type | ||||||
|         already_created_fields = {f.attname for f in cls._meta.local_fields} | from graphene.utils.copy_fields import copy_fields | ||||||
|  | from graphene.utils.get_graphql_type import get_graphql_type | ||||||
|  | from graphene.utils.get_fields import get_fields | ||||||
|  | from graphene.utils.as_field import as_field | ||||||
|  | from graphene.generators import generate_objecttype | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class SQLAlchemyObjectTypeMeta(type(ObjectType)): | ||||||
|  |     def _construct_fields(cls, fields, options): | ||||||
|  |         only_fields = cls._meta.only | ||||||
|  |         exclude_fields = cls._meta.exclude | ||||||
|         inspected_model = sqlalchemyinspect(cls._meta.model) |         inspected_model = sqlalchemyinspect(cls._meta.model) | ||||||
| 
 | 
 | ||||||
|         # Get all the columns for the relationships on the model |         # Get all the columns for the relationships on the model | ||||||
|         for relationship in inspected_model.relationships: |         for relationship in inspected_model.relationships: | ||||||
|             is_not_in_only = only_fields and relationship.key not in only_fields |             is_not_in_only = only_fields and relationship.key not in only_fields | ||||||
|             is_already_created = relationship.key in already_created_fields |             is_already_created = relationship.key in fields | ||||||
|             is_excluded = relationship.key in exclude_fields or is_already_created |             is_excluded = relationship.key in exclude_fields or is_already_created | ||||||
|             if is_not_in_only or is_excluded: |             if is_not_in_only or is_excluded: | ||||||
|                 # We skip this field if we specify only_fields and is not |                 # We skip this field if we specify only_fields and is not | ||||||
|                 # in there. Or when we excldue this field in exclude_fields |                 # in there. Or when we excldue this field in exclude_fields | ||||||
|                 continue |                 continue | ||||||
|             converted_relationship = convert_sqlalchemy_relationship(relationship) |             converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) | ||||||
|             cls.add_to_class(relationship.key, converted_relationship) |             if not converted_relationship: | ||||||
|  |                 continue | ||||||
|  |             name = relationship.key | ||||||
|  |             fields[name] = as_field(converted_relationship) | ||||||
| 
 | 
 | ||||||
|         for name, column in inspected_model.columns.items(): |         for name, column in inspected_model.columns.items(): | ||||||
|             is_not_in_only = only_fields and name not in only_fields |             is_not_in_only = only_fields and name not in only_fields | ||||||
|             is_already_created = name in already_created_fields |             is_already_created = name in fields | ||||||
|             is_excluded = name in exclude_fields or is_already_created |             is_excluded = name in exclude_fields or is_already_created | ||||||
|             if is_not_in_only or is_excluded: |             if is_not_in_only or is_excluded: | ||||||
|                 # We skip this field if we specify only_fields and is not |                 # We skip this field if we specify only_fields and is not | ||||||
|                 # in there. Or when we excldue this field in exclude_fields |                 # in there. Or when we excldue this field in exclude_fields | ||||||
|                 continue |                 continue | ||||||
|             converted_column = convert_sqlalchemy_column(column) |             converted_column = convert_sqlalchemy_column(column, options.registry) | ||||||
|             cls.add_to_class(name, converted_column) |             if not converted_column: | ||||||
|  |                 continue | ||||||
|  |             fields[name] = as_field(converted_column) | ||||||
| 
 | 
 | ||||||
|     def construct(cls, *args, **kwargs): |         fields = copy_fields(Field, fields, parent=cls) | ||||||
|         cls = super(SQLAlchemyObjectTypeMeta, cls).construct(*args, **kwargs) | 
 | ||||||
|         if not cls._meta.abstract: |         return fields | ||||||
|             if not cls._meta.model: | 
 | ||||||
|                 raise Exception( |     @staticmethod | ||||||
|                     'SQLAlchemy ObjectType %s must have a model in the Meta class attr' % |     def _create_objecttype(cls, name, bases, attrs): | ||||||
|                     cls) |         # super_new = super(SQLAlchemyObjectTypeMeta, cls).__new__ | ||||||
|             elif not inspect.isclass(cls._meta.model) or not is_mapped(cls._meta.model): |         super_new = type.__new__ | ||||||
|                 raise Exception('Provided model in %s is not a SQLAlchemy model' % cls) | 
 | ||||||
|  |         # Also ensure initialization is only performed for subclasses of Model | ||||||
|  |         # (excluding Model class itself). | ||||||
|  |         if not is_base_type(bases, SQLAlchemyObjectTypeMeta): | ||||||
|  |             return super_new(cls, name, bases, attrs) | ||||||
|  | 
 | ||||||
|  |         options = Options( | ||||||
|  |             attrs.pop('Meta', None), | ||||||
|  |             name=None, | ||||||
|  |             description=None, | ||||||
|  |             model=None, | ||||||
|  |             fields=(), | ||||||
|  |             exclude=(), | ||||||
|  |             only=(), | ||||||
|  |             interfaces=(), | ||||||
|  |             registry=None | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |         if not options.registry: | ||||||
|  |             options.registry = get_global_registry() | ||||||
|  |         assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry) | ||||||
|  |         assert is_mapped(options.model), 'You need to pass a valid SQLAlchemy Model in {}.Meta, received "{}".'.format(name, options.model) | ||||||
|  | 
 | ||||||
|  |         interfaces = tuple(options.interfaces) | ||||||
|  |         fields = get_fields(ObjectType, attrs, bases, interfaces) | ||||||
|  |         attrs = attrs_without_fields(attrs, fields) | ||||||
|  |         cls = super_new(cls, name, bases, dict(attrs, _meta=options)) | ||||||
|  | 
 | ||||||
|  |         base_interfaces = tuple(b for b in bases if issubclass(b, Interface)) | ||||||
|  |         options.get_fields = partial(cls._construct_fields, fields, options) | ||||||
|  |         options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces)) | ||||||
|  | 
 | ||||||
|  |         options.graphql_type = generate_objecttype(cls) | ||||||
|  | 
 | ||||||
|  |         if issubclass(cls, SQLAlchemyObjectType): | ||||||
|  |             options.registry.register(cls) | ||||||
| 
 | 
 | ||||||
|             cls.construct_fields() |  | ||||||
|         return cls |         return cls | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class InstanceObjectType(ObjectType): | class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)): | ||||||
| 
 |     is_type_of = None | ||||||
|     class Meta: |  | ||||||
|         abstract = True |  | ||||||
| 
 |  | ||||||
|     def __init__(self, _root=None): |  | ||||||
|         super(InstanceObjectType, self).__init__(_root=_root) |  | ||||||
|         assert not self._root or isinstance(self._root, self._meta.model), ( |  | ||||||
|             '{} received a non-compatible instance ({}) ' |  | ||||||
|             'when expecting {}'.format( |  | ||||||
|                 self.__class__.__name__, |  | ||||||
|                 self._root.__class__.__name__, |  | ||||||
|                 self._meta.model.__name__ |  | ||||||
|             )) |  | ||||||
| 
 |  | ||||||
|     @property |  | ||||||
|     def instance(self): |  | ||||||
|         return self._root |  | ||||||
| 
 |  | ||||||
|     @instance.setter |  | ||||||
|     def instance(self, value): |  | ||||||
|         self._root = value |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SQLAlchemyObjectType(six.with_metaclass( | class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, type(Node)): | ||||||
|         SQLAlchemyObjectTypeMeta, InstanceObjectType)): |  | ||||||
| 
 | 
 | ||||||
|     class Meta: |     @staticmethod | ||||||
|         abstract = True |     def _get_interface_options(meta): | ||||||
|  |         return Options( | ||||||
|  |             meta, | ||||||
|  |             name=None, | ||||||
|  |             description=None, | ||||||
|  |             model=None, | ||||||
|  |             graphql_type=None, | ||||||
|  |             registry=False | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     def _create_interface(cls, name, bases, attrs): | ||||||
|  |         cls = super(SQLAlchemyNodeMeta, cls)._create_interface(cls, name, bases, attrs) | ||||||
|  |         if not cls._meta.registry: | ||||||
|  |             cls._meta.registry = get_global_registry() | ||||||
|  |         assert isinstance(cls._meta.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry.'.format(name) | ||||||
|  |         return cls | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class SQLAlchemyConnection(Connection): | class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)): | ||||||
|     pass |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, NodeMeta): |  | ||||||
|     pass |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class NodeInstance(Node, InstanceObjectType): |  | ||||||
| 
 |  | ||||||
|     class Meta: |  | ||||||
|         abstract = True |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class SQLAlchemyNode(six.with_metaclass( |  | ||||||
|         SQLAlchemyNodeMeta, NodeInstance)): |  | ||||||
| 
 |  | ||||||
|     class Meta: |  | ||||||
|         abstract = True |  | ||||||
| 
 |  | ||||||
|     def to_global_id(self): |  | ||||||
|         id_ = getattr(self.instance, self._meta.identifier) |  | ||||||
|         return self.global_id(id_) |  | ||||||
| 
 |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_node(cls, id, info=None): |     def get_node(cls, id, context, info): | ||||||
|         try: |         try: | ||||||
|             model = cls._meta.model |             model = cls._meta.model | ||||||
|             identifier = cls._meta.identifier |             query = get_query(model, context) | ||||||
|             query = get_query(model, info) |             return query.get(id) | ||||||
|             instance = query.filter(getattr(model, identifier) == id).one() |  | ||||||
|             return cls(instance) |  | ||||||
|         except NoResultFound: |         except NoResultFound: | ||||||
|             return None |             return None | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def resolve_id(cls, root, args, context, info): | ||||||
|  |         return root.__mapper__.primary_key_from_instance(root)[0] | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def resolve_type(cls, type_instance, context, info): | ||||||
|  |         # We get the model from the _meta in the SQLAlchemy class/instance | ||||||
|  |         model = type(type_instance) | ||||||
|  |         graphene_type = cls._meta.registry.get_type_for_model(model) | ||||||
|  |         if graphene_type: | ||||||
|  |             return get_graphql_type(graphene_type) | ||||||
|  | 
 | ||||||
|  |         raise Exception("Type not found for model \"{}\"".format(model)) | ||||||
|  |  | ||||||
|  | @ -1,28 +1,14 @@ | ||||||
| from sqlalchemy.ext.declarative.api import DeclarativeMeta | from sqlalchemy.ext.declarative.api import DeclarativeMeta | ||||||
| from sqlalchemy.orm.query import Query |  | ||||||
| 
 |  | ||||||
| from graphene.utils import LazyList |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_type_for_model(schema, model): | def get_session(context): | ||||||
|     schema = schema |     return context.get('session') | ||||||
|     types = schema.types.values() |  | ||||||
|     for _type in types: |  | ||||||
|         type_model = hasattr(_type, '_meta') and getattr( |  | ||||||
|             _type._meta, 'model', None) |  | ||||||
|         if model == type_model: |  | ||||||
|             return _type |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_session(info): | def get_query(model, context): | ||||||
|     schema = info.schema.graphene_schema |  | ||||||
|     return schema.options.get('session') |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def get_query(model, info): |  | ||||||
|     query = getattr(model, 'query', None) |     query = getattr(model, 'query', None) | ||||||
|     if not query: |     if not query: | ||||||
|         session = get_session(info) |         session = get_session(context) | ||||||
|         if not session: |         if not session: | ||||||
|             raise Exception('A query in the model Base or a session in the schema is required for querying.\n' |             raise Exception('A query in the model Base or a session in the schema is required for querying.\n' | ||||||
|                             'Read more http://graphene-python.org/docs/sqlalchemy/tips/#querying') |                             'Read more http://graphene-python.org/docs/sqlalchemy/tips/#querying') | ||||||
|  | @ -30,20 +16,5 @@ def get_query(model, info): | ||||||
|     return query |     return query | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class WrappedQuery(LazyList): |  | ||||||
| 
 |  | ||||||
|     def __len__(self): |  | ||||||
|         # Dont calculate the length using len(query), as this will |  | ||||||
|         # evaluate the whole queryset and return it's length. |  | ||||||
|         # Use .count() instead |  | ||||||
|         return self._origin.count() |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def maybe_query(value): |  | ||||||
|     if isinstance(value, Query): |  | ||||||
|         return WrappedQuery(value) |  | ||||||
|     return value |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| def is_mapped(obj): | def is_mapped(obj): | ||||||
|     return isinstance(obj, DeclarativeMeta) |     return isinstance(obj, DeclarativeMeta) | ||||||
|  |  | ||||||
|  | @ -96,11 +96,10 @@ class IterableConnectionField(Field): | ||||||
|     @property |     @property | ||||||
|     def connection(self): |     def connection(self): | ||||||
|         from .node import Node |         from .node import Node | ||||||
|         graphql_type = super(IterableConnectionField, self).type |         if issubclass(self._type, Node): | ||||||
|         if issubclass(graphql_type.graphene_type, Node): |             connection_type = self._type.get_default_connection() | ||||||
|             connection_type = graphql_type.graphene_type.get_default_connection() |  | ||||||
|         else: |         else: | ||||||
|             connection_type = graphql_type.graphene_type |             connection_type = self._type | ||||||
|         assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self)) |         assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self)) | ||||||
|         return connection_type |         return connection_type | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -84,9 +84,13 @@ class Node(six.with_metaclass(NodeMeta, Interface)): | ||||||
|         #     return to_global_id(type, id) |         #     return to_global_id(type, id) | ||||||
|         # raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__)) |         # raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__)) | ||||||
| 
 | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def resolve_id(cls, root, args, context, info): | ||||||
|  |         return getattr(root, 'id', None) | ||||||
|  | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def id_resolver(cls, root, args, context, info): |     def id_resolver(cls, root, args, context, info): | ||||||
|         return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None)) |         return cls.to_global_id(info.parent_type.name, cls.resolve_id(root, args, context, info)) | ||||||
| 
 | 
 | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_node_from_global_id(cls, global_id, context, info): |     def get_node_from_global_id(cls, global_id, context, info): | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user