Updated SQLAlchemy code to work with latest version of graphene.

This commit is contained in:
Syrus Akbary 2016-08-14 15:42:27 -07:00
parent c414b3f688
commit f296a2a73f
10 changed files with 161 additions and 179 deletions

View File

@ -1,33 +1,35 @@
import graphene
from graphene import relay
from graphene_sqlalchemy import (SQLAlchemyConnectionField,
SQLAlchemyObjectType,
SQLAlchemyNode)
SQLAlchemyObjectType)
from models import Department as DepartmentModel
from models import Employee as EmployeeModel
from models import Role as RoleModel
class Department(SQLAlchemyNode, SQLAlchemyObjectType):
class Department(SQLAlchemyObjectType):
class Meta:
model = DepartmentModel
interfaces = (relay.Node, )
class Employee(SQLAlchemyNode, SQLAlchemyObjectType):
class Employee(SQLAlchemyObjectType):
class Meta:
model = EmployeeModel
interfaces = (relay.Node, )
class Role(SQLAlchemyNode, SQLAlchemyObjectType):
class Role(SQLAlchemyObjectType):
class Meta:
model = RoleModel
interfaces = (relay.Node, )
class Query(graphene.ObjectType):
node = SQLAlchemyNode.Field()
node = relay.Node.Field()
all_employees = SQLAlchemyConnectionField(Employee)
all_roles = SQLAlchemyConnectionField(Role)
role = graphene.Field(Role)

View File

@ -1,10 +1,9 @@
from .types import (
SQLAlchemyObjectType,
SQLAlchemyNode
)
from .fields import (
SQLAlchemyConnectionField
)
__all__ = ['SQLAlchemyObjectType', 'SQLAlchemyNode',
__all__ = ['SQLAlchemyObjectType',
'SQLAlchemyConnectionField']

View File

@ -3,8 +3,8 @@ from sqlalchemy import types
from sqlalchemy.orm import interfaces
from sqlalchemy.dialects import postgresql
from graphene import Enum, ID, Boolean, Float, Int, String, List, Field
from graphene.relay import Node
from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic
from graphene.relay import is_node
from graphene.types.json import JSONString
from .fields import SQLAlchemyConnectionField
@ -18,16 +18,20 @@ except ImportError:
def convert_sqlalchemy_relationship(relationship, registry):
direction = relationship.direction
model = relationship.mapper.entity
_type = registry.get_type_for_model(model)
if not _type:
return None
if direction == interfaces.MANYTOONE:
return Field(_type)
elif (direction == interfaces.ONETOMANY or
direction == interfaces.MANYTOMANY):
if issubclass(_type, Node):
return SQLAlchemyConnectionField(_type)
return List(_type)
def dynamic_type():
_type = registry.get_type_for_model(model)
if not _type:
return None
if direction == interfaces.MANYTOONE:
return Field(_type)
elif (direction == interfaces.ONETOMANY or
direction == interfaces.MANYTOMANY):
if is_node(_type):
return SQLAlchemyConnectionField(_type)
return Field(List(_type))
return Dynamic(dynamic_type)
def convert_sqlalchemy_column(column, registry=None):

View File

@ -1,3 +1,4 @@
from functools import partial
from sqlalchemy.orm.query import Query
from graphene.relay import ConnectionField
@ -9,17 +10,13 @@ class SQLAlchemyConnectionField(ConnectionField):
@property
def model(self):
return self.connection._meta.node._meta.model
def get_query(self, context):
return get_query(self.model, context)
def default_resolver(self, root, args, context, info):
return getattr(root, self.source or self.attname, self.get_query(context))
return self.type._meta.node._meta.model
@staticmethod
def connection_resolver(resolver, connection, root, args, context, info):
def connection_resolver(resolver, connection, model, root, args, context, info):
iterable = resolver(root, args, context, info)
if iterable is None:
iterable = get_query(model, context)
if isinstance(iterable, Query):
_len = iterable.count()
else:
@ -33,3 +30,6 @@ class SQLAlchemyConnectionField(ConnectionField):
connection_type=connection,
edge_type=connection.Edge,
)
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.type, self.model)

View File

@ -7,10 +7,10 @@ class Registry(object):
from .types import SQLAlchemyObjectType
assert issubclass(cls, SQLAlchemyObjectType), 'Only SQLAlchemyObjectType can be registered, received "{}"'.format(cls.__name__)
assert cls._meta.registry == self, 'Registry for a Model have to match.'
assert self.get_type_for_model(cls._meta.model) in [None, cls], (
'SQLAlchemy model "{}" already associated with '
'another type "{}".'
).format(cls._meta.model, self._registry[cls._meta.model])
# assert self.get_type_for_model(cls._meta.model) in [None, cls], (
# 'SQLAlchemy model "{}" already associated with '
# 'another type "{}".'
# ).format(cls._meta.model, self._registry[cls._meta.model])
self._registry[cls._meta.model] = cls
def get_type_for_model(self, model):

View File

@ -5,11 +5,12 @@ from sqlalchemy_utils.types.choice import ChoiceType
from sqlalchemy.dialects import postgresql
import graphene
from graphene.relay import Node
from graphene.types.json import JSONString
from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship)
from ..fields import SQLAlchemyConnectionField
from ..types import SQLAlchemyObjectType, SQLAlchemyNode
from ..types import SQLAlchemyObjectType
from ..registry import Registry
from .models import Article, Pet, Reporter
@ -19,7 +20,7 @@ def assert_column_conversion(sqlalchemy_type, graphene_field, **kwargs):
column = Column(sqlalchemy_type, doc='Custom Help Text', **kwargs)
graphene_type = convert_sqlalchemy_column(column)
assert isinstance(graphene_type, graphene_field)
field = graphene_type.as_field()
field = graphene_type.Field()
assert field.description == 'Custom Help Text'
return field
@ -101,16 +102,17 @@ def test_should_choice_convert_enum():
Table('translatedmodel', Base.metadata, column)
graphene_type = convert_sqlalchemy_column(column)
assert issubclass(graphene_type, graphene.Enum)
assert graphene_type._meta.graphql_type.name == 'TRANSLATEDMODEL_LANGUAGE'
assert graphene_type._meta.graphql_type.description == 'Language'
assert graphene_type._meta.name == 'TRANSLATEDMODEL_LANGUAGE'
assert graphene_type._meta.description == 'Language'
assert graphene_type._meta.enum.__members__['es'].value == 'Spanish'
assert graphene_type._meta.enum.__members__['en'].value == 'English'
def test_should_manytomany_convert_connectionorlist():
registry = Registry()
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, registry)
assert not graphene_type
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
def test_should_manytomany_convert_connectionorlist_list():
@ -118,26 +120,30 @@ def test_should_manytomany_convert_connectionorlist_list():
class Meta:
model = Pet
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(graphene_type, graphene.List)
assert graphene_type.of_type == A._meta.graphql_type
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field)
assert isinstance(graphene_type.type, graphene.List)
assert graphene_type.type.of_type == A
def test_should_manytomany_convert_connectionorlist_connection():
class A(SQLAlchemyNode, SQLAlchemyObjectType):
class A(SQLAlchemyObjectType):
class Meta:
model = Pet
interfaces = (Node, )
graphene_type = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(graphene_type, SQLAlchemyConnectionField)
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry)
assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField)
def test_should_manytoone_convert_connectionorlist():
registry = Registry()
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, registry)
assert not graphene_type
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
def test_should_manytoone_convert_connectionorlist_list():
@ -145,19 +151,24 @@ def test_should_manytoone_convert_connectionorlist_list():
class Meta:
model = Reporter
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field)
assert graphene_type.type == A._meta.graphql_type
assert graphene_type.type == A
def test_should_manytoone_convert_connectionorlist_connection():
class A(SQLAlchemyNode, SQLAlchemyObjectType):
class A(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
graphene_type = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field)
assert graphene_type.type == A._meta.graphql_type
assert graphene_type.type == A
def test_should_postgresql_uuid_convert():

View File

@ -3,8 +3,8 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
import graphene
from graphene import relay
from ..types import (SQLAlchemyNode, SQLAlchemyObjectType)
from graphene.relay import Node
from ..types import SQLAlchemyObjectType
from ..fields import SQLAlchemyConnectionField
from .models import Article, Base, Editor, Reporter
@ -93,10 +93,11 @@ def test_should_query_well(session):
def test_should_node(session):
setup_fixtures(session)
class ReporterNode(SQLAlchemyNode, SQLAlchemyObjectType):
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
@ -105,17 +106,18 @@ def test_should_node(session):
def resolve_articles(self, *args, **kwargs):
return [Article(headline='Hi!')]
class ArticleNode(SQLAlchemyNode, SQLAlchemyObjectType):
class ArticleNode(SQLAlchemyObjectType):
class Meta:
model = Article
interfaces = (Node, )
# @classmethod
# def get_node(cls, id, info):
# return Article(id=1, headline='Article node')
class Query(graphene.ObjectType):
node = SQLAlchemyNode.Field()
node = Node.Field()
reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode)
@ -194,13 +196,14 @@ def test_should_node(session):
def test_should_custom_identifier(session):
setup_fixtures(session)
class EditorNode(SQLAlchemyNode, SQLAlchemyObjectType):
class EditorNode(SQLAlchemyObjectType):
class Meta:
model = Editor
interfaces = (Node, )
class Query(graphene.ObjectType):
node = SQLAlchemyNode.Field()
node = Node.Field()
all_editors = SQLAlchemyConnectionField(EditorNode)
query = '''

View File

@ -28,7 +28,7 @@ def test_should_map_fields_correctly():
model = Reporter
registry = Registry()
assert ReporterType2._meta.get_fields().keys() == ['id', 'firstName', 'lastName', 'email']
assert ReporterType2._meta.fields.keys() == ['id', 'first_name', 'last_name', 'email', 'pets', 'articles']
def test_should_map_only_few_fields():
@ -36,5 +36,5 @@ def test_should_map_only_few_fields():
class Meta:
model = Reporter
only = ('id', 'email')
assert Reporter2._meta.get_fields().keys() == ['id', 'email']
only_fields = ('id', 'email')
assert Reporter2._meta.fields.keys() == ['id', 'email']

View File

@ -1,10 +1,10 @@
from graphql.type import GraphQLObjectType, GraphQLInterfaceType
from graphql.type.definition import GraphQLFieldDefinition
from graphql import GraphQLInt
from pytest import raises
from graphene import Schema
from ..types import (SQLAlchemyNode, SQLAlchemyObjectType)
from graphene import Schema, Interface, ObjectType
from graphene.relay import Node, is_node
from ..types import SQLAlchemyObjectType
from ..registry import Registry
from graphene import Field, Int
@ -14,6 +14,7 @@ from .models import Article, Reporter
registry = Registry()
class Character(SQLAlchemyObjectType):
'''Character description'''
class Meta:
@ -21,22 +22,21 @@ class Character(SQLAlchemyObjectType):
registry = registry
class Human(SQLAlchemyNode, SQLAlchemyObjectType):
class Human(SQLAlchemyObjectType):
'''Human description'''
pub_date = Int()
class Meta:
model = Article
exclude = ('id', )
exclude_fields = ('id', )
registry = registry
interfaces = (Node, )
def test_sqlalchemy_interface():
assert isinstance(SQLAlchemyNode._meta.graphql_type, GraphQLInterfaceType)
assert issubclass(Node, Interface)
assert issubclass(Node, Node)
# @patch('graphene.contrib.sqlalchemy.tests.models.Article.filter', return_value=Article(id=1))
@ -47,14 +47,13 @@ def test_sqlalchemy_interface():
def test_objecttype_registered():
object_type = Character._meta.graphql_type
assert isinstance(object_type, GraphQLObjectType)
assert issubclass(Character, ObjectType)
assert Character._meta.model == Reporter
assert object_type.get_fields().keys() == ['articles', 'id', 'firstName', 'lastName', 'email']
assert Character._meta.fields.keys() == ['id', 'first_name', 'last_name', 'email', 'pets', 'articles']
# def test_sqlalchemynode_idfield():
# idfield = SQLAlchemyNode._meta.fields_map['id']
# idfield = Node._meta.fields_map['id']
# assert isinstance(idfield, GlobalIDField)
@ -64,14 +63,12 @@ def test_objecttype_registered():
def test_node_replacedfield():
idfield = Human._meta.graphql_type.get_fields()['pubDate']
assert isinstance(idfield, GraphQLFieldDefinition)
assert idfield.type == GraphQLInt
idfield = Human._meta.fields['pub_date']
assert isinstance(idfield, Field)
assert idfield.type == Int
def test_object_type():
object_type = Human._meta.graphql_type
object_type.get_fields()
assert isinstance(object_type, GraphQLObjectType)
assert object_type.get_fields().keys() == ['id', 'pubDate', 'reporter', 'headline', 'reporterId']
assert SQLAlchemyNode._meta.graphql_type in object_type.get_interfaces()
assert issubclass(Human, ObjectType)
assert Human._meta.fields.keys() == ['id', 'pub_date', 'headline', 'reporter_id', 'reporter']
assert is_node(Human)

View File

@ -1,138 +1,112 @@
from collections import OrderedDict
import six
from sqlalchemy.inspection import inspect as sqlalchemyinspect
from sqlalchemy.orm.exc import NoResultFound
from graphene import ObjectType
from graphene.relay import Node
from graphene.relay import is_node
from .converter import (convert_sqlalchemy_column,
convert_sqlalchemy_relationship)
from .utils import is_mapped
from functools import partial
from graphene import Field, Interface
from graphene.types.objecttype import ObjectTypeMeta
from graphene.types.options import Options
from graphene.types.objecttype import attrs_without_fields, get_interfaces
from .registry import Registry, get_global_registry
from .utils import get_query
from graphene.utils.is_base_type import is_base_type
from graphene.utils.copy_fields import copy_fields
from graphene.utils.get_graphql_type import get_graphql_type
from graphene.utils.get_fields import get_fields
from graphene.utils.as_field import as_field
from graphene.generators import generate_objecttype
from graphene.types.utils import get_fields_in_type
from .utils import get_query
class SQLAlchemyObjectTypeMeta(type(ObjectType)):
def _construct_fields(cls, fields, options):
only_fields = cls._meta.only
exclude_fields = cls._meta.exclude
class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
def _construct_fields(cls, all_fields, options):
only_fields = cls._meta.only_fields
exclude_fields = cls._meta.exclude_fields
inspected_model = sqlalchemyinspect(cls._meta.model)
# Get all the columns for the relationships on the model
for relationship in inspected_model.relationships:
is_not_in_only = only_fields and relationship.key not in only_fields
is_already_created = relationship.key in fields
is_excluded = relationship.key in exclude_fields or is_already_created
if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields
continue
converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry)
if not converted_relationship:
continue
name = relationship.key
fields[name] = as_field(converted_relationship)
fields = OrderedDict()
for name, column in inspected_model.columns.items():
is_not_in_only = only_fields and name not in only_fields
is_already_created = name in fields
is_already_created = name in all_fields
is_excluded = name in exclude_fields or is_already_created
if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields
continue
converted_column = convert_sqlalchemy_column(column, options.registry)
if not converted_column:
continue
fields[name] = as_field(converted_column)
fields[name] = converted_column
fields = copy_fields(Field, fields, parent=cls)
# Get all the columns for the relationships on the model
for relationship in inspected_model.relationships:
is_not_in_only = only_fields and relationship.key not in only_fields
is_already_created = relationship.key in all_fields
is_excluded = relationship.key in exclude_fields or is_already_created
if is_not_in_only or is_excluded:
# We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields
continue
converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry)
name = relationship.key
fields[name] = converted_relationship
return fields
@staticmethod
def _create_objecttype(cls, name, bases, attrs):
# super_new = super(SQLAlchemyObjectTypeMeta, cls).__new__
super_new = type.__new__
def __new__(cls, name, bases, attrs):
# Also ensure initialization is only performed for subclasses of Model
# (excluding Model class itself).
if not is_base_type(bases, SQLAlchemyObjectTypeMeta):
return super_new(cls, name, bases, attrs)
return type.__new__(cls, name, bases, attrs)
options = Options(
attrs.pop('Meta', None),
name=None,
description=None,
name=name,
description=attrs.pop('__doc__', None),
model=None,
fields=(),
exclude=(),
only=(),
fields=None,
only_fields=(),
exclude_fields=(),
id='id',
interfaces=(),
registry=None
)
if not options.registry:
options.registry = get_global_registry()
assert isinstance(options.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry, received "{}".'.format(name, options.registry)
assert is_mapped(options.model), 'You need to pass a valid SQLAlchemy Model in {}.Meta, received "{}".'.format(name, options.model)
assert isinstance(options.registry, Registry), (
'The attribute registry in {}.Meta needs to be an'
' instance of Registry, received "{}".'
).format(name, options.registry)
assert is_mapped(options.model), (
'You need to pass a valid SQLAlchemy Model in '
'{}.Meta, received "{}".'
).format(name, options.model)
interfaces = tuple(options.interfaces)
fields = get_fields(ObjectType, attrs, bases, interfaces)
attrs = attrs_without_fields(attrs, fields)
cls = super_new(cls, name, bases, dict(attrs, _meta=options))
base_interfaces = tuple(b for b in bases if issubclass(b, Interface))
options.get_fields = partial(cls._construct_fields, fields, options)
options.get_interfaces = tuple(get_interfaces(interfaces + base_interfaces))
cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options))
options.graphql_type = generate_objecttype(cls)
options.registry.register(cls)
if issubclass(cls, SQLAlchemyObjectType):
options.registry.register(cls)
options.sqlalchemy_fields = get_fields_in_type(
ObjectType,
cls._construct_fields(options.fields, options)
)
options.fields.update(options.sqlalchemy_fields)
return cls
class SQLAlchemyObjectType(six.with_metaclass(SQLAlchemyObjectTypeMeta, ObjectType)):
is_type_of = None
@classmethod
def is_type_of(cls, root, context, info):
if isinstance(root, cls):
return True
if not is_mapped(type(root)):
raise Exception((
'Received incompatible instance "{}".'
).format(root))
return type(root) == cls._meta.model
class SQLAlchemyNodeMeta(SQLAlchemyObjectTypeMeta, type(Node)):
@staticmethod
def _get_interface_options(meta):
return Options(
meta,
name=None,
description=None,
graphql_type=None,
registry=False
)
@staticmethod
def _create_interface(cls, name, bases, attrs):
cls = super(SQLAlchemyNodeMeta, cls)._create_interface(cls, name, bases, attrs)
if not cls._meta.registry:
cls._meta.registry = get_global_registry()
assert isinstance(cls._meta.registry, Registry), 'The attribute registry in {}.Meta needs to be an instance of Registry.'.format(name)
return cls
class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)):
@classmethod
def get_node(cls, id, context, info):
try:
@ -142,16 +116,8 @@ class SQLAlchemyNode(six.with_metaclass(SQLAlchemyNodeMeta, Node)):
except NoResultFound:
return None
@classmethod
def resolve_id(cls, root, args, context, info):
return root.__mapper__.primary_key_from_instance(root)[0]
@classmethod
def resolve_type(cls, type_instance, context, info):
# We get the model from the _meta in the SQLAlchemy class/instance
model = type(type_instance)
graphene_type = cls._meta.registry.get_type_for_model(model)
if graphene_type:
return get_graphql_type(graphene_type)
raise Exception("Type not found for model \"{}\"".format(model))
def resolve_id(root, args, context, info):
graphene_type = info.parent_type.graphene_type
if is_node(graphene_type):
return root.__mapper__.primary_key_from_instance(root)[0]
return getattr(root, graphene_type._meta.id, None)