mirror of
https://github.com/graphql-python/graphene.git
synced 2025-03-04 20:05:48 +03:00
First working version of graphene-sqlalchemy
This commit is contained in:
parent
79d7636ab6
commit
af4c63512c
|
@ -4,7 +4,6 @@ from django.utils.encoding import force_text
|
|||
from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull
|
||||
from graphene.types.json import JSONString
|
||||
from graphene.types.datetime import DateTime
|
||||
from graphene.types.json import JSONString
|
||||
from graphene.utils.str_converters import to_const
|
||||
from graphene.relay import Node
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from graphene.contrib.sqlalchemy.types import (
|
||||
from .types import (
|
||||
SQLAlchemyObjectType,
|
||||
SQLAlchemyNode
|
||||
)
|
||||
from graphene.contrib.sqlalchemy.fields import (
|
||||
from .fields import (
|
||||
SQLAlchemyConnectionField
|
||||
)
|
||||
|
||||
|
|
|
@ -3,9 +3,10 @@ from sqlalchemy import types
|
|||
from sqlalchemy.orm import interfaces
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from graphene import Enum, ID, Boolean, Float, Int, String, List
|
||||
from graphene import Enum, ID, Boolean, Float, Int, String, List, Field
|
||||
from graphene.relay import Node
|
||||
from graphene.types.json import JSONString
|
||||
from .fields import ConnectionOrListField, SQLAlchemyModelField
|
||||
from .fields import SQLAlchemyConnectionField
|
||||
|
||||
try:
|
||||
from sqlalchemy_utils.types.choice import ChoiceType
|
||||
|
@ -14,23 +15,27 @@ except ImportError:
|
|||
pass
|
||||
|
||||
|
||||
def convert_sqlalchemy_relationship(relationship):
|
||||
def convert_sqlalchemy_relationship(relationship, registry):
|
||||
direction = relationship.direction
|
||||
model = relationship.mapper.entity
|
||||
model_field = SQLAlchemyModelField(model, description=relationship.doc)
|
||||
_type = registry.get_type_for_model(model)
|
||||
if not _type:
|
||||
return None
|
||||
if direction == interfaces.MANYTOONE:
|
||||
return model_field
|
||||
return Field(_type)
|
||||
elif (direction == interfaces.ONETOMANY or
|
||||
direction == interfaces.MANYTOMANY):
|
||||
return ConnectionOrListField(model_field)
|
||||
if issubclass(_type, Node):
|
||||
return SQLAlchemyConnectionField(_type)
|
||||
return List(_type)
|
||||
|
||||
|
||||
def convert_sqlalchemy_column(column):
|
||||
return convert_sqlalchemy_type(getattr(column, 'type', None), column)
|
||||
def convert_sqlalchemy_column(column, registry=None):
|
||||
return convert_sqlalchemy_type(getattr(column, 'type', None), column, registry)
|
||||
|
||||
|
||||
@singledispatch
|
||||
def convert_sqlalchemy_type(type, column):
|
||||
def convert_sqlalchemy_type(type, column, registry=None):
|
||||
raise Exception(
|
||||
"Don't know how to convert the SQLAlchemy field %s (%s)" % (column, column.__class__))
|
||||
|
||||
|
@ -45,14 +50,14 @@ def convert_sqlalchemy_type(type, column):
|
|||
@convert_sqlalchemy_type.register(types.Enum)
|
||||
@convert_sqlalchemy_type.register(postgresql.ENUM)
|
||||
@convert_sqlalchemy_type.register(postgresql.UUID)
|
||||
def convert_column_to_string(type, column):
|
||||
def convert_column_to_string(type, column, registry=None):
|
||||
return String(description=column.doc)
|
||||
|
||||
|
||||
@convert_sqlalchemy_type.register(types.SmallInteger)
|
||||
@convert_sqlalchemy_type.register(types.BigInteger)
|
||||
@convert_sqlalchemy_type.register(types.Integer)
|
||||
def convert_column_to_int_or_id(type, column):
|
||||
def convert_column_to_int_or_id(type, column, registry=None):
|
||||
if column.primary_key:
|
||||
return ID(description=column.doc)
|
||||
else:
|
||||
|
@ -60,24 +65,24 @@ def convert_column_to_int_or_id(type, column):
|
|||
|
||||
|
||||
@convert_sqlalchemy_type.register(types.Boolean)
|
||||
def convert_column_to_boolean(type, column):
|
||||
def convert_column_to_boolean(type, column, registry=None):
|
||||
return Boolean(description=column.doc)
|
||||
|
||||
|
||||
@convert_sqlalchemy_type.register(types.Float)
|
||||
@convert_sqlalchemy_type.register(types.Numeric)
|
||||
def convert_column_to_float(type, column):
|
||||
def convert_column_to_float(type, column, registry=None):
|
||||
return Float(description=column.doc)
|
||||
|
||||
|
||||
@convert_sqlalchemy_type.register(ChoiceType)
|
||||
def convert_column_to_enum(type, column):
|
||||
def convert_column_to_enum(type, column, registry=None):
|
||||
name = '{}_{}'.format(column.table.name, column.name).upper()
|
||||
return Enum(name, type.choices, description=column.doc)
|
||||
|
||||
|
||||
@convert_sqlalchemy_type.register(postgresql.ARRAY)
|
||||
def convert_postgres_array_to_list(type, column):
|
||||
def convert_postgres_array_to_list(type, column, registry=None):
|
||||
graphene_type = convert_sqlalchemy_type(column.type.item_type, column)
|
||||
return List(graphene_type, description=column.doc)
|
||||
|
||||
|
@ -85,5 +90,5 @@ def convert_postgres_array_to_list(type, column):
|
|||
@convert_sqlalchemy_type.register(postgresql.HSTORE)
|
||||
@convert_sqlalchemy_type.register(postgresql.JSON)
|
||||
@convert_sqlalchemy_type.register(postgresql.JSONB)
|
||||
def convert_json_to_string(type, column):
|
||||
def convert_json_to_string(type, column, registry=None):
|
||||
return JSONString(description=column.doc)
|
||||
|
|
|
@ -1,69 +1,35 @@
|
|||
from ...core.exceptions import SkipField
|
||||
from ...core.fields import Field
|
||||
from ...core.types.base import FieldType
|
||||
from ...core.types.definitions import List
|
||||
from ...relay import ConnectionField
|
||||
from ...relay.utils import is_node
|
||||
from .utils import get_query, get_type_for_model, maybe_query
|
||||
from sqlalchemy.orm.query import Query
|
||||
|
||||
|
||||
class DefaultQuery(object):
|
||||
pass
|
||||
from graphene.relay import ConnectionField
|
||||
from graphql_relay.connection.arrayconnection import connection_from_list_slice
|
||||
from .utils import get_query
|
||||
|
||||
|
||||
class SQLAlchemyConnectionField(ConnectionField):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs['default'] = kwargs.pop('default', lambda: DefaultQuery)
|
||||
return super(SQLAlchemyConnectionField, self).__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.type._meta.model
|
||||
return self.connection._meta.node._meta.model
|
||||
|
||||
def from_list(self, connection_type, resolved, args, context, info):
|
||||
if resolved is DefaultQuery:
|
||||
resolved = get_query(self.model, info)
|
||||
query = maybe_query(resolved)
|
||||
return super(SQLAlchemyConnectionField, self).from_list(connection_type, query, args, context, info)
|
||||
def get_query(self, context):
|
||||
return get_query(self.model, context)
|
||||
|
||||
def default_resolver(self, root, args, context, info):
|
||||
return getattr(root, self.source or self.attname, self.get_query(context))
|
||||
|
||||
class ConnectionOrListField(Field):
|
||||
|
||||
def internal_type(self, schema):
|
||||
model_field = self.type
|
||||
field_object_type = model_field.get_object_type(schema)
|
||||
if not field_object_type:
|
||||
raise SkipField()
|
||||
if is_node(field_object_type):
|
||||
field = SQLAlchemyConnectionField(field_object_type)
|
||||
@staticmethod
|
||||
def connection_resolver(resolver, connection, root, args, context, info):
|
||||
iterable = resolver(root, args, context, info)
|
||||
if isinstance(iterable, Query):
|
||||
_len = iterable.count()
|
||||
else:
|
||||
field = Field(List(field_object_type))
|
||||
field.contribute_to_class(self.object_type, self.attname)
|
||||
return schema.T(field)
|
||||
|
||||
|
||||
class SQLAlchemyModelField(FieldType):
|
||||
|
||||
def __init__(self, model, *args, **kwargs):
|
||||
self.model = model
|
||||
super(SQLAlchemyModelField, self).__init__(*args, **kwargs)
|
||||
|
||||
def internal_type(self, schema):
|
||||
_type = self.get_object_type(schema)
|
||||
if not _type and self.parent._meta.only_fields:
|
||||
raise Exception(
|
||||
"Table %r is not accessible by the schema. "
|
||||
"You can either register the type manually "
|
||||
"using @schema.register. "
|
||||
"Or disable the field in %s" % (
|
||||
self.model,
|
||||
self.parent,
|
||||
)
|
||||
)
|
||||
if not _type:
|
||||
raise SkipField()
|
||||
return schema.T(_type)
|
||||
|
||||
def get_object_type(self, schema):
|
||||
return get_type_for_model(schema, self.model)
|
||||
_len = len(iterable)
|
||||
return connection_from_list_slice(
|
||||
iterable,
|
||||
args,
|
||||
slice_start=0,
|
||||
list_length=_len,
|
||||
list_slice_length=_len,
|
||||
connection_type=connection,
|
||||
edge_type=connection.Edge,
|
||||
)
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
from ...core.classtypes.objecttype import ObjectTypeOptions
|
||||
from ...relay.types import Node
|
||||
from ...relay.utils import is_node
|
||||
|
||||
VALID_ATTRS = ('model', 'only_fields', 'exclude_fields', 'identifier')
|
||||
|
||||
|
||||
class SQLAlchemyOptions(ObjectTypeOptions):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SQLAlchemyOptions, self).__init__(*args, **kwargs)
|
||||
self.model = None
|
||||
self.identifier = "id"
|
||||
self.valid_attrs += VALID_ATTRS
|
||||
self.only_fields = None
|
||||
self.exclude_fields = []
|
||||
self.filter_fields = None
|
||||
self.filter_order_by = None
|
||||
|
||||
def contribute_to_class(self, cls, name):
|
||||
super(SQLAlchemyOptions, self).contribute_to_class(cls, name)
|
||||
if is_node(cls):
|
||||
self.exclude_fields = list(self.exclude_fields) + ['id']
|
||||
self.interfaces.append(Node)
|
28
graphene-sqlalchemy/graphene_sqlalchemy/registry.py
Normal file
28
graphene-sqlalchemy/graphene_sqlalchemy/registry.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
class Registry(object):
|
||||
def __init__(self):
|
||||
self._registry = {}
|
||||
self._registry_models = {}
|
||||
|
||||
def register(self, cls):
|
||||
from .types import SQLAlchemyObjectType
|
||||
assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__)
|
||||
assert cls._meta.registry == self, 'Registry for a Model have to match.'
|
||||
self._registry[cls._meta.model] = cls
|
||||
|
||||
def get_type_for_model(self, model):
|
||||
return self._registry.get(model)
|
||||
|
||||
|
||||
registry = None
|
||||
|
||||
|
||||
def get_global_registry():
|
||||
global registry
|
||||
if not registry:
|
||||
registry = Registry()
|
||||
return registry
|
||||
|
||||
|
||||
def reset_global_registry():
|
||||
global registry
|
||||
registry = None
|
|
@ -5,11 +5,12 @@ from sqlalchemy_utils.types.choice import ChoiceType
|
|||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
import graphene
|
||||
from graphene.core.types.custom_scalars import JSONString
|
||||
from graphene.contrib.sqlalchemy.converter import (convert_sqlalchemy_column,
|
||||
convert_sqlalchemy_relationship)
|
||||
from graphene.contrib.sqlalchemy.fields import (ConnectionOrListField,
|
||||
SQLAlchemyModelField)
|
||||
from graphene.types.json import JSONString
|
||||
from ..converter import (convert_sqlalchemy_column,
|
||||
convert_sqlalchemy_relationship)
|
||||
from ..fields import SQLAlchemyConnectionField
|
||||
from ..types import SQLAlchemyObjectType, SQLAlchemyNode
|
||||
from ..registry import Registry
|
||||
|
||||
from .models import Article, Pet, Reporter
|
||||
|
||||
|
@ -100,30 +101,63 @@ def test_should_choice_convert_enum():
|
|||
Table('translatedmodel', Base.metadata, column)
|
||||
graphene_type = convert_sqlalchemy_column(column)
|
||||
assert issubclass(graphene_type, graphene.Enum)
|
||||
assert graphene_type._meta.type_name == 'TRANSLATEDMODEL_LANGUAGE'
|
||||
assert graphene_type._meta.description == 'Language'
|
||||
assert graphene_type.__enum__.__members__['es'].value == 'Spanish'
|
||||
assert graphene_type.__enum__.__members__['en'].value == 'English'
|
||||
assert graphene_type._meta.graphql_type.name == 'TRANSLATEDMODEL_LANGUAGE'
|
||||
assert graphene_type._meta.graphql_type.description == 'Language'
|
||||
assert graphene_type._meta.enum.__members__['es'].value == 'Spanish'
|
||||
assert graphene_type._meta.enum.__members__['en'].value == 'English'
|
||||
|
||||
|
||||
def test_should_manytomany_convert_connectionorlist():
|
||||
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property)
|
||||
assert isinstance(graphene_type, ConnectionOrListField)
|
||||
assert isinstance(graphene_type.type, SQLAlchemyModelField)
|
||||
assert graphene_type.type.model == Pet
|
||||
registry = Registry()
|
||||
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, registry)
|
||||
assert not graphene_type
|
||||
|
||||
|
||||
def test_should_manytomany_convert_connectionorlist_list():
|
||||
class A(SQLAlchemyObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
|
||||
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
|
||||
assert isinstance(graphene_type, graphene.List)
|
||||
assert graphene_type.of_type == A._meta.graphql_type
|
||||
|
||||
|
||||
def test_should_manytomany_convert_connectionorlist_connection():
|
||||
class A(SQLAlchemyNode, SQLAlchemyObjectType):
|
||||
class Meta:
|
||||
model = Pet
|
||||
|
||||
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
|
||||
assert isinstance(graphene_type, SQLAlchemyConnectionField)
|
||||
|
||||
|
||||
|
||||
|
||||
def test_should_manytoone_convert_connectionorlist():
|
||||
field = convert_sqlalchemy_relationship(Article.reporter.property)
|
||||
assert isinstance(field, SQLAlchemyModelField)
|
||||
assert field.model == Reporter
|
||||
registry = Registry()
|
||||
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, registry)
|
||||
assert not graphene_type
|
||||
|
||||
|
||||
def test_should_onetomany_convert_model():
|
||||
graphene_type = convert_sqlalchemy_relationship(Reporter.articles.property)
|
||||
assert isinstance(graphene_type, ConnectionOrListField)
|
||||
assert isinstance(graphene_type.type, SQLAlchemyModelField)
|
||||
assert graphene_type.type.model == Article
|
||||
def test_should_manytoone_convert_connectionorlist_list():
|
||||
class A(SQLAlchemyObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
|
||||
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
|
||||
assert isinstance(graphene_type, graphene.Field)
|
||||
assert graphene_type.type == A._meta.graphql_type
|
||||
|
||||
|
||||
def test_should_manytoone_convert_connectionorlist_connection():
|
||||
class A(SQLAlchemyNode, SQLAlchemyObjectType):
|
||||
class Meta:
|
||||
model = Reporter
|
||||
|
||||
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
|
||||
assert isinstance(graphene_type, graphene.Field)
|
||||
assert graphene_type.type == A._meta.graphql_type
|
||||
|
||||
|
||||
def test_should_postgresql_uuid_convert():
|
||||
|
|
|
@ -4,8 +4,8 @@ from sqlalchemy.orm import scoped_session, sessionmaker
|
|||
|
||||
import graphene
|
||||
from graphene import relay
|
||||
from graphene.contrib.sqlalchemy import (SQLAlchemyConnectionField,
|
||||
SQLAlchemyNode, SQLAlchemyObjectType)
|
||||
from ..types import (SQLAlchemyNode, SQLAlchemyObjectType)
|
||||
from ..fields import SQLAlchemyConnectionField
|
||||
|
||||
from .models import Article, Base, Editor, Reporter
|
||||
|
||||
|
@ -52,7 +52,7 @@ def test_should_query_well(session):
|
|||
|
||||
class Query(graphene.ObjectType):
|
||||
reporter = graphene.Field(ReporterType)
|
||||
reporters = ReporterType.List()
|
||||
reporters = graphene.List(ReporterType)
|
||||
|
||||
def resolve_reporter(self, *args, **kwargs):
|
||||
return session.query(Reporter).first()
|
||||
|
@ -93,7 +93,7 @@ def test_should_query_well(session):
|
|||
def test_should_node(session):
|
||||
setup_fixtures(session)
|
||||
|
||||
class ReporterNode(SQLAlchemyNode):
|
||||
class ReporterNode(SQLAlchemyNode, SQLAlchemyObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
|
@ -105,7 +105,7 @@ def test_should_node(session):
|
|||
def resolve_articles(self, *args, **kwargs):
|
||||
return [Article(headline='Hi!')]
|
||||
|
||||
class ArticleNode(SQLAlchemyNode):
|
||||
class ArticleNode(SQLAlchemyNode, SQLAlchemyObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
|
@ -115,7 +115,7 @@ def test_should_node(session):
|
|||
# return Article(id=1, headline='Article node')
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
node = relay.NodeField()
|
||||
node = SQLAlchemyNode.Field()
|
||||
reporter = graphene.Field(ReporterNode)
|
||||
article = graphene.Field(ArticleNode)
|
||||
all_articles = SQLAlchemyConnectionField(ArticleNode)
|
||||
|
@ -185,8 +185,8 @@ def test_should_node(session):
|
|||
'headline': 'Hi!'
|
||||
}
|
||||
}
|
||||
schema = graphene.Schema(query=Query, session=session)
|
||||
result = schema.execute(query)
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query, context_value={'session': session})
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
||||
|
@ -194,14 +194,13 @@ def test_should_node(session):
|
|||
def test_should_custom_identifier(session):
|
||||
setup_fixtures(session)
|
||||
|
||||
class EditorNode(SQLAlchemyNode):
|
||||
class EditorNode(SQLAlchemyNode, SQLAlchemyObjectType):
|
||||
|
||||
class Meta:
|
||||
model = Editor
|
||||
identifier = "editor_id"
|
||||
|
||||
class Query(graphene.ObjectType):
|
||||
node = relay.NodeField(EditorNode)
|
||||
node = SQLAlchemyNode.Field()
|
||||
all_editors = SQLAlchemyConnectionField(EditorNode)
|
||||
|
||||
query = '''
|
||||
|
@ -215,7 +214,9 @@ def test_should_custom_identifier(session):
|
|||
}
|
||||
},
|
||||
node(id: "RWRpdG9yTm9kZTox") {
|
||||
name
|
||||
...on EditorNode {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
@ -233,7 +234,7 @@ def test_should_custom_identifier(session):
|
|||
}
|
||||
}
|
||||
|
||||
schema = graphene.Schema(query=Query, session=session)
|
||||
result = schema.execute(query)
|
||||
schema = graphene.Schema(query=Query)
|
||||
result = schema.execute(query, context_value={'session': session})
|
||||
assert not result.errors
|
||||
assert result.data == expected
|
||||
|
|
|
@ -1,25 +1,24 @@
|
|||
from py.test import raises
|
||||
|
||||
from graphene.contrib.sqlalchemy import SQLAlchemyObjectType
|
||||
from tests.utils import assert_equal_lists
|
||||
from ..types import SQLAlchemyObjectType
|
||||
|
||||
from .models import Reporter
|
||||
from ..registry import Registry
|
||||
|
||||
|
||||
def test_should_raise_if_no_model():
|
||||
with raises(Exception) as excinfo:
|
||||
class Character1(SQLAlchemyObjectType):
|
||||
pass
|
||||
assert 'model in the Meta' in str(excinfo.value)
|
||||
assert 'valid SQLAlchemy Model' in str(excinfo.value)
|
||||
|
||||
|
||||
def test_should_raise_if_model_is_invalid():
|
||||
with raises(Exception) as excinfo:
|
||||
class Character2(SQLAlchemyObjectType):
|
||||
|
||||
class Meta:
|
||||
model = 1
|
||||
assert 'not a SQLAlchemy model' in str(excinfo.value)
|
||||
assert 'valid SQLAlchemy Model' in str(excinfo.value)
|
||||
|
||||
|
||||
def test_should_map_fields_correctly():
|
||||
|
@ -27,10 +26,9 @@ def test_should_map_fields_correctly():
|
|||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
assert_equal_lists(
|
||||
ReporterType2._meta.fields_map.keys(),
|
||||
['articles', 'first_name', 'last_name', 'email', 'pets', 'id']
|
||||
)
|
||||
registry = Registry()
|
||||
|
||||
assert ReporterType2._meta.get_fields().keys() == ['id', 'firstName', 'lastName', 'email']
|
||||
|
||||
|
||||
def test_should_map_only_few_fields():
|
||||
|
@ -38,8 +36,5 @@ def test_should_map_only_few_fields():
|
|||
|
||||
class Meta:
|
||||
model = Reporter
|
||||
only_fields = ('id', 'email')
|
||||
assert_equal_lists(
|
||||
Reporter2._meta.fields_map.keys(),
|
||||
['id', 'email']
|
||||
)
|
||||
only = ('id', 'email')
|
||||
assert Reporter2._meta.get_fields().keys() == ['id', 'email']
|
||||
|
|
|
@ -1,38 +1,42 @@
|
|||
from graphql.type import GraphQLObjectType
|
||||
from graphql.type import GraphQLObjectType, GraphQLInterfaceType
|
||||
from graphql.type.definition import GraphQLFieldDefinition
|
||||
from graphql import GraphQLInt
|
||||
from pytest import raises
|
||||
|
||||
from graphene import Schema
|
||||
from graphene.contrib.sqlalchemy.types import (SQLAlchemyNode,
|
||||
SQLAlchemyObjectType)
|
||||
from graphene.core.fields import Field
|
||||
from graphene.core.types.scalars import Int
|
||||
from graphene.relay.fields import GlobalIDField
|
||||
from tests.utils import assert_equal_lists
|
||||
from ..types import (SQLAlchemyNode, SQLAlchemyObjectType)
|
||||
from ..registry import Registry
|
||||
|
||||
from graphene import Field, Int
|
||||
# from tests.utils import assert_equal_lists
|
||||
|
||||
from .models import Article, Reporter
|
||||
|
||||
schema = Schema()
|
||||
|
||||
registry = Registry()
|
||||
|
||||
class Character(SQLAlchemyObjectType):
|
||||
'''Character description'''
|
||||
class Meta:
|
||||
model = Reporter
|
||||
registry = registry
|
||||
|
||||
|
||||
@schema.register
|
||||
class Human(SQLAlchemyNode):
|
||||
class Human(SQLAlchemyNode, SQLAlchemyObjectType):
|
||||
'''Human description'''
|
||||
|
||||
pub_date = Int()
|
||||
|
||||
class Meta:
|
||||
model = Article
|
||||
exclude_fields = ('id', )
|
||||
exclude = ('id', )
|
||||
registry = registry
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def test_sqlalchemy_interface():
|
||||
assert SQLAlchemyNode._meta.interface is True
|
||||
assert isinstance(SQLAlchemyNode._meta.graphql_type, GraphQLInterfaceType)
|
||||
|
||||
|
||||
# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1))
|
||||
|
@ -43,60 +47,31 @@ def test_sqlalchemy_interface():
|
|||
|
||||
|
||||
def test_objecttype_registered():
|
||||
object_type = schema.T(Character)
|
||||
object_type = Character._meta.graphql_type
|
||||
assert isinstance(object_type, GraphQLObjectType)
|
||||
assert Character._meta.model == Reporter
|
||||
assert_equal_lists(
|
||||
object_type.get_fields().keys(),
|
||||
['articles', 'firstName', 'lastName', 'email', 'id']
|
||||
)
|
||||
assert object_type.get_fields().keys() == ['articles', 'id', 'firstName', 'lastName', 'email']
|
||||
|
||||
|
||||
def test_sqlalchemynode_idfield():
|
||||
idfield = SQLAlchemyNode._meta.fields_map['id']
|
||||
assert isinstance(idfield, GlobalIDField)
|
||||
# def test_sqlalchemynode_idfield():
|
||||
# idfield = SQLAlchemyNode._meta.fields_map['id']
|
||||
# assert isinstance(idfield, GlobalIDField)
|
||||
|
||||
|
||||
def test_node_idfield():
|
||||
idfield = Human._meta.fields_map['id']
|
||||
assert isinstance(idfield, GlobalIDField)
|
||||
# def test_node_idfield():
|
||||
# idfield = Human._meta.fields_map['id']
|
||||
# assert isinstance(idfield, GlobalIDField)
|
||||
|
||||
|
||||
def test_node_replacedfield():
|
||||
idfield = Human._meta.fields_map['pub_date']
|
||||
assert isinstance(idfield, Field)
|
||||
assert schema.T(idfield).type == schema.T(Int())
|
||||
|
||||
|
||||
def test_interface_objecttype_init_none():
|
||||
h = Human()
|
||||
assert h._root is None
|
||||
|
||||
|
||||
def test_interface_objecttype_init_good():
|
||||
instance = Article()
|
||||
h = Human(instance)
|
||||
assert h._root == instance
|
||||
|
||||
|
||||
def test_interface_objecttype_init_unexpected():
|
||||
with raises(AssertionError) as excinfo:
|
||||
Human(object())
|
||||
assert str(excinfo.value) == "Human received a non-compatible instance (object) when expecting Article"
|
||||
idfield = Human._meta.graphql_type.get_fields()['pubDate']
|
||||
assert isinstance(idfield, GraphQLFieldDefinition)
|
||||
assert idfield.type == GraphQLInt
|
||||
|
||||
|
||||
def test_object_type():
|
||||
object_type = schema.T(Human)
|
||||
Human._meta.fields_map
|
||||
assert Human._meta.interface is False
|
||||
object_type = Human._meta.graphql_type
|
||||
object_type.get_fields()
|
||||
assert isinstance(object_type, GraphQLObjectType)
|
||||
assert_equal_lists(
|
||||
object_type.get_fields().keys(),
|
||||
['headline', 'id', 'reporter', 'reporterId', 'pubDate']
|
||||
)
|
||||
assert schema.T(SQLAlchemyNode) in object_type.get_interfaces()
|
||||
|
||||
|
||||
def test_node_notinterface():
|
||||
assert Human._meta.interface is False
|
||||
assert SQLAlchemyNode in Human._meta.interfaces
|
||||
assert object_type.get_fields().keys() == ['id', 'pubDate', 'reporter', 'headline', 'reporterId']
|
||||
assert SQLAlchemyNode._meta.graphql_type in object_type.get_interfaces()
|
||||
|
|
|
@ -5,13 +5,12 @@ from ..utils import get_session
|
|||
|
||||
def test_get_session():
|
||||
session = 'My SQLAlchemy session'
|
||||
schema = Schema(session=session)
|
||||
|
||||
class Query(ObjectType):
|
||||
x = String()
|
||||
|
||||
def resolve_x(self, args, info):
|
||||
return get_session(info)
|
||||
def resolve_x(self, args, context, info):
|
||||
return get_session(context)
|
||||
|
||||
query = '''
|
||||
query ReporterQuery {
|
||||
|
@ -19,7 +18,7 @@ def test_get_session():
|
|||
}
|
||||
'''
|
||||
|
||||
schema = Schema(query=Query, session=session)
|
||||
result = schema.execute(query)
|
||||
schema = Schema(query=Query)
|
||||
result = schema.execute(query, context_value={'session': session})
|
||||
assert not result.errors
|
||||
assert result.data['x'] == session
|
||||
|
|
|
@ -1,125 +1,158 @@
|
|||
import inspect
|
||||
|
||||
import six
|
||||
from sqlalchemy.inspection import inspect as sqlalchemyinspect
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from ...core.classtypes.objecttype import ObjectType, ObjectTypeMeta
|
||||
from ...relay.types import Connection, Node, NodeMeta
|
||||
from graphene import ObjectType
|
||||
from graphene.relay import Node
|
||||
from .converter import (convert_sqlalchemy_column,
|
||||
convert_sqlalchemy_relationship)
|
||||
from .options import SQLAlchemyOptions
|
||||
from .utils import get_query, is_mapped
|
||||
from .utils import is_mapped
|
||||
|
||||
from functools import partial
|
||||
|
||||
|
||||
class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
|
||||
options_class = SQLAlchemyOptions
|
||||
from graphene import Field, Interface
|
||||
from graphene.types.options import Options
|
||||
from graphene.types.objecttype import attrs_without_fields, get_interfaces
|
||||
|
||||
def construct_fields(cls):
|
||||
only_fields = cls._meta.only_fields
|
||||
exclude_fields = cls._meta.exclude_fields
|
||||
already_created_fields = {f.attname for f in cls._meta.local_fields}
|
||||
from .registry import Registry, get_global_registry
|
||||
from .utils import get_query
|
||||
from graphene.utils.is_base_type import is_base_type
|
||||
from graphene.utils.copy_fields import copy_fields
|
||||
from graphene.utils.get_graphql_type import get_graphql_type
|
||||
from graphene.utils.get_fields import get_fields
|
||||
from graphene.utils.as_field import as_field
|
||||
from graphene.generators import generate_objecttype
|
||||
|
||||
|
||||
class SQLAlchemyObjectTypeMeta(type(ObjectType)):
|
||||
def _construct_fields(cls, fields, options):
|
||||
only_fields = cls._meta.only
|
||||
exclude_fields = cls._meta.exclude
|
||||
inspected_model = sqlalchemyinspect(cls._meta.model)
|
||||
|
||||
# Get all the columns for the relationships on the model
|
||||
for relationship in inspected_model.relationships:
|
||||
is_not_in_only = only_fields and relationship.key not in only_fields
|
||||
is_already_created = relationship.key in already_created_fields
|
||||
is_already_created = relationship.key in fields
|
||||
is_excluded = relationship.key in exclude_fields or is_already_created
|
||||
if is_not_in_only or is_excluded:
|
||||
# We skip this field if we specify only_fields and is not
|
||||
# in there. Or when we excldue this field in exclude_fields
|
||||
continue
|
||||
converted_relationship = convert_sqlalchemy_relationship(relationship)
|
||||
cls.add_to_class(relationship.key, converted_relationship)
|
||||
converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry)
|
||||
if not converted_relationship:
|
||||
continue
|
||||
name = relationship.key
|
||||
fields[name] = as_field(converted_relationship)
|
||||
|
||||
for name, column in inspected_model.columns.items():
|
||||
is_not_in_only = only_fields and name not in only_fields
|
||||
is_already_created = name in already_created_fields
|
||||
is_already_created = name in fields
|
||||
is_excluded = name in exclude_fields or is_already_created
|
||||
if is_not_in_only or is_excluded:
|
||||
# We skip this field if we specify only_fields and is not
|
||||
# in there. Or when we excldue this field in exclude_fields
|
||||
continue
|
||||
converted_column = convert_sqlalchemy_column(column)
|
||||
cls.add_to_class(name, converted_column)
|
||||
converted_column = convert_sqlalchemy_column(column, options.registry)
|
||||
if not converted_column:
|
||||
continue
|
||||
fields[name] = as_field(converted_column)
|
||||
|
||||
def construct(cls, *args, **kwargs):
|
||||
cls = super(SQLAlchemyObjectTypeMeta, cls).construct(*args, **kwargs)
|
||||
if not cls._meta.abstract:
|
||||
if not cls._meta.model:
|
||||
raise Exception(
|
||||
'SQLAlchemy ObjectType %s must have a model in the Meta class attr' %
|
||||
cls)
|
||||
elif not inspect.isclass(cls._meta.model) or not is_mapped(cls._meta.model):
|
||||
raise Exception('Provided model in %s is not a SQLAlchemy model' % cls)
|
||||
fields = copy_fields(Field, fields, parent=cls)
|
||||
|
||||
return fields
|
||||
|
||||
@staticmethod
|
||||
def _create_objecttype(cls, name, bases, attrs):
|
||||
# super_new = super(SQLAlchemyObjectTypeMeta, cls).__new__
|
||||
super_new = type.__new__
|
||||
|
||||
# Also ensure initialization is only performed for subclasses of Model
|
||||
# (excluding Model class itself).
|
||||
if not is_base_type(bases, SQLAlchemyObjectTypeMeta):
|
||||
return super_new(cls, name, bases, attrs)
|
||||
|
||||
options = Options(
|
||||
attrs.pop('Meta', None),
|
||||
name=None,
|
||||
description=None,
|
||||
model=None,
|
||||
fields=(),
|
||||
exclude=(),
|
||||
only=(),
|
||||
interfaces=(),
|
||||
registry=None
|
||||
)
|
||||
|
||||
if not options.registry:
|
||||
options.registry = get_global_registry()
|
||||
assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry)
|
||||
assert is_mapped(options.model), 'You need to pass a valid SQLAlchemy Model in {}.Meta, received "{}".'.format(name, options.model)
|
||||
|
||||
interfaces = tuple(options.interfaces)
|
||||
fields = get_fields(ObjectType, attrs, bases, interfaces)
|
||||
attrs = attrs_without_fields(attrs, fields)
|
||||
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
|
||||
|
||||
base_interfaces = tuple(b for b in bases if issubclass(b, Interface))
|
||||
options.get_fields = partial(cls._construct_fields, fields, options)
|
||||
options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces))
|
||||
|
||||
options.graphql_type = generate_objecttype(cls)
|
||||
|
||||
if issubclass(cls, SQLAlchemyObjectType):
|
||||
options.registry.register(cls)
|
||||
|
||||
cls.construct_fields()
|
||||
return cls
|
||||
|
||||
|
||||
class InstanceObjectType(ObjectType):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
def __init__(self, _root=None):
|
||||
super(InstanceObjectType, self).__init__(_root=_root)
|
||||
assert not self._root or isinstance(self._root, self._meta.model), (
|
||||
'{} received a non-compatible instance ({}) '
|
||||
'when expecting {}'.format(
|
||||
self.__class__.__name__,
|
||||
self._root.__class__.__name__,
|
||||
self._meta.model.__name__
|
||||
))
|
||||
|
||||
@property
|
||||
def instance(self):
|
||||
return self._root
|
||||
|
||||
@instance.setter
|
||||
def instance(self, value):
|
||||
self._root = value
|
||||
class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)):
|
||||
is_type_of = None
|
||||
|
||||
|
||||
class SQLAlchemyObjectType(six.with_metaclass(
|
||||
SQLAlchemyObjectTypeMeta, InstanceObjectType)):
|
||||
class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, type(Node)):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
@staticmethod
|
||||
def _get_interface_options(meta):
|
||||
return Options(
|
||||
meta,
|
||||
name=None,
|
||||
description=None,
|
||||
model=None,
|
||||
graphql_type=None,
|
||||
registry=False
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_interface(cls, name, bases, attrs):
|
||||
cls = super(SQLAlchemyNodeMeta, cls)._create_interface(cls, name, bases, attrs)
|
||||
if not cls._meta.registry:
|
||||
cls._meta.registry = get_global_registry()
|
||||
assert isinstance(cls._meta.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry.'.format(name)
|
||||
return cls
|
||||
|
||||
|
||||
class SQLAlchemyConnection(Connection):
|
||||
pass
|
||||
|
||||
|
||||
class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, NodeMeta):
|
||||
pass
|
||||
|
||||
|
||||
class NodeInstance(Node, InstanceObjectType):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class SQLAlchemyNode(six.with_metaclass(
|
||||
SQLAlchemyNodeMeta, NodeInstance)):
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
def to_global_id(self):
|
||||
id_ = getattr(self.instance, self._meta.identifier)
|
||||
return self.global_id(id_)
|
||||
|
||||
class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)):
|
||||
@classmethod
|
||||
def get_node(cls, id, info=None):
|
||||
def get_node(cls, id, context, info):
|
||||
try:
|
||||
model = cls._meta.model
|
||||
identifier = cls._meta.identifier
|
||||
query = get_query(model, info)
|
||||
instance = query.filter(getattr(model, identifier) == id).one()
|
||||
return cls(instance)
|
||||
query = get_query(model, context)
|
||||
return query.get(id)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def resolve_id(cls, root, args, context, info):
|
||||
return root.__mapper__.primary_key_from_instance(root)[0]
|
||||
|
||||
@classmethod
|
||||
def resolve_type(cls, type_instance, context, info):
|
||||
# We get the model from the _meta in the SQLAlchemy class/instance
|
||||
model = type(type_instance)
|
||||
graphene_type = cls._meta.registry.get_type_for_model(model)
|
||||
if graphene_type:
|
||||
return get_graphql_type(graphene_type)
|
||||
|
||||
raise Exception("Type not found for model \"{}\"".format(model))
|
||||
|
|
|
@ -1,28 +1,14 @@
|
|||
from sqlalchemy.ext.declarative.api import DeclarativeMeta
|
||||
from sqlalchemy.orm.query import Query
|
||||
|
||||
from graphene.utils import LazyList
|
||||
|
||||
|
||||
def get_type_for_model(schema, model):
|
||||
schema = schema
|
||||
types = schema.types.values()
|
||||
for _type in types:
|
||||
type_model = hasattr(_type, '_meta') and getattr(
|
||||
_type._meta, 'model', None)
|
||||
if model == type_model:
|
||||
return _type
|
||||
def get_session(context):
|
||||
return context.get('session')
|
||||
|
||||
|
||||
def get_session(info):
|
||||
schema = info.schema.graphene_schema
|
||||
return schema.options.get('session')
|
||||
|
||||
|
||||
def get_query(model, info):
|
||||
def get_query(model, context):
|
||||
query = getattr(model, 'query', None)
|
||||
if not query:
|
||||
session = get_session(info)
|
||||
session = get_session(context)
|
||||
if not session:
|
||||
raise Exception('A query in the model Base or a session in the schema is required for querying.\n'
|
||||
'Read more http://graphene-python.org/docs/sqlalchemy/tips/#querying')
|
||||
|
@ -30,20 +16,5 @@ def get_query(model, info):
|
|||
return query
|
||||
|
||||
|
||||
class WrappedQuery(LazyList):
|
||||
|
||||
def __len__(self):
|
||||
# Dont calculate the length using len(query), as this will
|
||||
# evaluate the whole queryset and return it's length.
|
||||
# Use .count() instead
|
||||
return self._origin.count()
|
||||
|
||||
|
||||
def maybe_query(value):
|
||||
if isinstance(value, Query):
|
||||
return WrappedQuery(value)
|
||||
return value
|
||||
|
||||
|
||||
def is_mapped(obj):
|
||||
return isinstance(obj, DeclarativeMeta)
|
||||
|
|
|
@ -96,11 +96,10 @@ class IterableConnectionField(Field):
|
|||
@property
|
||||
def connection(self):
|
||||
from .node import Node
|
||||
graphql_type = super(IterableConnectionField, self).type
|
||||
if issubclass(graphql_type.graphene_type, Node):
|
||||
connection_type = graphql_type.graphene_type.get_default_connection()
|
||||
if issubclass(self._type, Node):
|
||||
connection_type = self._type.get_default_connection()
|
||||
else:
|
||||
connection_type = graphql_type.graphene_type
|
||||
connection_type = self._type
|
||||
assert issubclass(connection_type, Connection), '{} type have to be a subclass of Connection'.format(str(self))
|
||||
return connection_type
|
||||
|
||||
|
|
|
@ -84,9 +84,13 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
|
|||
# return to_global_id(type, id)
|
||||
# raise NotImplementedError("You need to implement {}.to_global_id".format(cls.__name__))
|
||||
|
||||
@classmethod
|
||||
def resolve_id(cls, root, args, context, info):
|
||||
return getattr(root, 'id', None)
|
||||
|
||||
@classmethod
|
||||
def id_resolver(cls, root, args, context, info):
|
||||
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None))
|
||||
return cls.to_global_id(info.parent_type.name, cls.resolve_id(root, args, context, info))
|
||||
|
||||
@classmethod
|
||||
def get_node_from_global_id(cls, global_id, context, info):
|
||||
|
|
Loading…
Reference in New Issue
Block a user