This commit is contained in:
Markus Padourek 2016-09-17 19:48:45 +00:00 committed by GitHub
commit 55bd686be0
20 changed files with 463 additions and 170 deletions

1
.gitignore vendored
View File

@ -75,6 +75,7 @@ target/
# PyCharm # PyCharm
.idea .idea
*.iml
# Databases # Databases
*.sqlite3 *.sqlite3

View File

@ -16,6 +16,8 @@ class Ship(graphene.ObjectType):
def get_node(cls, id, context, info): def get_node(cls, id, context, info):
return get_ship(id) return get_ship(id)
ShipConnection = relay.Connection.for_type(Ship)
class Faction(graphene.ObjectType): class Faction(graphene.ObjectType):
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
@ -24,7 +26,7 @@ class Faction(graphene.ObjectType):
interfaces = (relay.Node, ) interfaces = (relay.Node, )
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField(Ship, description='The ships used by the faction.') ships = relay.ConnectionField(ShipConnection, description='The ships used by the faction.')
@resolve_only_args @resolve_only_args
def resolve_ships(self, **args): def resolve_ships(self, **args):

View File

@ -54,13 +54,13 @@ type Ship implements Node {
} }
type ShipConnection { type ShipConnection {
pageInfo: PageInfo!
edges: [ShipEdge] edges: [ShipEdge]
pageInfo: PageInfo!
} }
type ShipEdge { type ShipEdge {
node: Ship
cursor: String! cursor: String!
node: Ship
} }
''' '''

View File

@ -1,7 +1,7 @@
from django.db import models from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from graphene import Enum, List, ID, Boolean, Float, Int, String, Field, NonNull, Field, Dynamic from graphene import Enum, List, ID, Boolean, Float, Int, String, NonNull, Field, Dynamic
from graphene.types.json import JSONString from graphene.types.json import JSONString
from graphene.types.datetime import DateTime from graphene.types.datetime import DateTime
from graphene.utils.str_converters import to_const from graphene.utils.str_converters import to_const

View File

@ -30,8 +30,10 @@ class Role(SQLAlchemyObjectType):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = relay.Node.Field() node = relay.Node.Field()
all_employees = SQLAlchemyConnectionField(Employee) employee_connection = relay.Connection.for_type(Employee)
all_roles = SQLAlchemyConnectionField(Role) role_connection = relay.Connection.for_type(Role)
all_employees = SQLAlchemyConnectionField(employee_connection)
all_roles = SQLAlchemyConnectionField(role_connection)
role = graphene.Field(Role) role = graphene.Field(Role)

View File

@ -4,7 +4,7 @@ from sqlalchemy.orm import interfaces
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic from graphene import Enum, ID, Boolean, Float, Int, String, List, Field, Dynamic
from graphene.relay import is_node from graphene.relay import is_node, Connection
from graphene.types.json import JSONString from graphene.types.json import JSONString
from .fields import SQLAlchemyConnectionField from .fields import SQLAlchemyConnectionField
@ -18,9 +18,10 @@ except ImportError:
pass pass
def convert_sqlalchemy_relationship(relationship, registry): def convert_sqlalchemy_relationship(relationship, registry, connections, type_name):
direction = relationship.direction direction = relationship.direction
model = relationship.mapper.entity model = relationship.mapper.entity
print(registry)
def dynamic_type(): def dynamic_type():
_type = registry.get_type_for_model(model) _type = registry.get_type_for_model(model)
@ -31,7 +32,13 @@ def convert_sqlalchemy_relationship(relationship, registry):
elif (direction == interfaces.ONETOMANY or elif (direction == interfaces.ONETOMANY or
direction == interfaces.MANYTOMANY): direction == interfaces.MANYTOMANY):
if is_node(_type): if is_node(_type):
return SQLAlchemyConnectionField(_type) try:
connection_type = connections[relationship.key]
except KeyError:
print(_type)
raise KeyError("No Connection provided for relationship {} on type {}. Specify it in its Meta "
"class on the 'connections' dict.".format(relationship.key, type_name))
return SQLAlchemyConnectionField(connection_type)
return Field(List(_type)) return Field(List(_type))
return Dynamic(dynamic_type) return Dynamic(dynamic_type)

View File

@ -6,7 +6,7 @@ from sqlalchemy_utils import ChoiceType, ScalarListType
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node, Connection
from graphene.types.json import JSONString from graphene.types.json import JSONString
from ..converter import (convert_sqlalchemy_column, from ..converter import (convert_sqlalchemy_column,
convert_sqlalchemy_composite, convert_sqlalchemy_composite,
@ -129,7 +129,7 @@ def test_should_scalar_list_convert_list():
def test_should_manytomany_convert_connectionorlist(): def test_should_manytomany_convert_connectionorlist():
registry = Registry() registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, registry, {}, '')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type() assert not dynamic_field.get_type()
@ -139,7 +139,7 @@ def test_should_manytomany_convert_connectionorlist_list():
class Meta: class Meta:
model = Pet model = Pet
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
@ -153,14 +153,33 @@ def test_should_manytomany_convert_connectionorlist_connection():
model = Pet model = Pet
interfaces = (Node, ) interfaces = (Node, )
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry) connections = {'pets': Connection.for_type(A)}
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField) assert isinstance(dynamic_field.get_type(), SQLAlchemyConnectionField)
assert issubclass(dynamic_field.get_type().type, Connection)
def test_should_rais_when_no_connections_is_provided_for_manyto_many():
class A(SQLAlchemyObjectType):
class Meta:
model = Pet
interfaces = (Node, )
connections = {}
with raises(KeyError) as ctx:
dynamic_field = convert_sqlalchemy_relationship(Reporter.pets.property, A._meta.registry, connections, 'A')
dynamic_field.get_type()
assert str(ctx.value) == ('\"No Connection provided for relationship pets on type A. Specify it in its Meta '
'class on the \'connections\' dict.\"')
def test_should_manytoone_convert_connectionorlist(): def test_should_manytoone_convert_connectionorlist():
registry = Registry() registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, registry, {}, '')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type() assert not dynamic_field.get_type()
@ -170,7 +189,7 @@ def test_should_manytoone_convert_connectionorlist_list():
class Meta: class Meta:
model = Reporter model = Reporter
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
@ -183,7 +202,7 @@ def test_should_manytoone_convert_connectionorlist_connection():
model = Reporter model = Reporter
interfaces = (Node, ) interfaces = (Node, )
dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Article.reporter.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)
@ -196,7 +215,7 @@ def test_should_onetoone_convert_field():
model = Article model = Article
interfaces = (Node, ) interfaces = (Node, )
dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry) dynamic_field = convert_sqlalchemy_relationship(Reporter.favorite_article.property, A._meta.registry, {}, 'A')
assert isinstance(dynamic_field, graphene.Dynamic) assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type() graphene_type = dynamic_field.get_type()
assert isinstance(graphene_type, graphene.Field) assert isinstance(graphene_type, graphene.Field)

View File

@ -3,11 +3,11 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
import graphene import graphene
from graphene.relay import Node from graphene.relay import Node, Connection
from ..types import SQLAlchemyObjectType from ..types import SQLAlchemyObjectType
from ..fields import SQLAlchemyConnectionField from ..fields import SQLAlchemyConnectionField
from .models import Article, Base, Editor, Reporter from .models import Article, Base, Editor, Reporter, Pet
db = create_engine('sqlite:///test_sqlalchemy.sqlite3') db = create_engine('sqlite:///test_sqlalchemy.sqlite3')
@ -46,10 +46,34 @@ def setup_fixtures(session):
def test_should_query_well(session): def test_should_query_well(session):
setup_fixtures(session) setup_fixtures(session)
class ArticleType(SQLAlchemyObjectType):
class Meta:
model = Article
ArticleTypeConnection = Connection.for_type(ArticleType)
ReporterNodeConnection = graphene.Dynamic(lambda: Connection.for_type(ReporterType))
class A(SQLAlchemyObjectType):
class Meta:
model = Pet
connections = {
'reporters': ReporterNodeConnection,
'articles': ArticleTypeConnection,
}
interfaces = (Node, )
AConnection = Connection.for_type(A)
class ReporterType(SQLAlchemyObjectType): class ReporterType(SQLAlchemyObjectType):
class Meta: class Meta:
model = Reporter model = Reporter
connections = {
'pets': AConnection,
'articles': AConnection,
}
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
reporter = graphene.Field(ReporterType) reporter = graphene.Field(ReporterType)
@ -94,16 +118,6 @@ def test_should_query_well(session):
def test_should_node(session): def test_should_node(session):
setup_fixtures(session) setup_fixtures(session)
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class ArticleNode(SQLAlchemyObjectType): class ArticleNode(SQLAlchemyObjectType):
class Meta: class Meta:
@ -114,11 +128,39 @@ def test_should_node(session):
# def get_node(cls, id, info): # def get_node(cls, id, info):
# return Article(id=1, headline='Article node') # return Article(id=1, headline='Article node')
ArticleNodeConnection = Connection.for_type(ArticleNode)
ReporterNodeConnection = graphene.Dynamic(lambda: Connection.for_type(ReporterNode))
class A(SQLAlchemyObjectType):
class Meta:
model = Pet
connections = {
'reporters': ReporterNodeConnection
}
interfaces = (Node, )
AConnection = Connection.for_type(A)
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
connections = {
'articles': ArticleNodeConnection,
'pets': AConnection,
}
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = Node.Field() node = Node.Field()
reporter = graphene.Field(ReporterNode) reporter = graphene.Field(ReporterNode)
article = graphene.Field(ArticleNode) article = graphene.Field(ArticleNode)
all_articles = SQLAlchemyConnectionField(ArticleNode) all_articles = SQLAlchemyConnectionField(ArticleNodeConnection)
def resolve_reporter(self, *args, **kwargs): def resolve_reporter(self, *args, **kwargs):
return session.query(Reporter).first() return session.query(Reporter).first()
@ -202,7 +244,8 @@ def test_should_custom_identifier(session):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
node = Node.Field() node = Node.Field()
all_editors = SQLAlchemyConnectionField(EditorNode) EditorNodeConnection = Connection.for_type(EditorNode)
all_editors = SQLAlchemyConnectionField(EditorNodeConnection)
query = ''' query = '''
query EditorQuery { query EditorQuery {
@ -250,23 +293,40 @@ def test_should_mutate_well(session):
model = Editor model = Editor
interfaces = (Node, ) interfaces = (Node, )
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class ArticleNode(SQLAlchemyObjectType): class ArticleNode(SQLAlchemyObjectType):
class Meta: class Meta:
model = Article model = Article
interfaces = (Node, ) interfaces = (Node, )
ArticleNodeConnection = Connection.for_type(ArticleNode)
ReporterNodeConnection = graphene.Dynamic(lambda: Connection.for_type(ReporterNode))
class A(SQLAlchemyObjectType):
class Meta:
model = Pet
connections = {
'reporters': ReporterNodeConnection
}
interfaces = (Node, )
AConnection = Connection.for_type(A)
class ReporterNode(SQLAlchemyObjectType):
class Meta:
model = Reporter
connections = {
'articles': ArticleNodeConnection,
'pets': AConnection,
}
interfaces = (Node, )
@classmethod
def get_node(cls, id, info):
return Reporter(id=2, first_name='Cookie Monster')
class CreateArticle(graphene.Mutation): class CreateArticle(graphene.Mutation):
class Input: class Input:
headline = graphene.String() headline = graphene.String()

View File

@ -18,12 +18,13 @@ from graphene.types.utils import yank_fields_from_attrs, merge
from .utils import get_query from .utils import get_query
def construct_fields(options): def construct_fields(options, type_name):
only_fields = options.only_fields only_fields = options.only_fields
exclude_fields = options.exclude_fields exclude_fields = options.exclude_fields
inspected_model = sqlalchemyinspect(options.model) inspected_model = sqlalchemyinspect(options.model)
fields = OrderedDict() fields = OrderedDict()
print('options in construct_fields', options)
for name, column in inspected_model.columns.items(): for name, column in inspected_model.columns.items():
is_not_in_only = only_fields and name not in only_fields is_not_in_only = only_fields and name not in only_fields
@ -56,7 +57,7 @@ def construct_fields(options):
# We skip this field if we specify only_fields and is not # We skip this field if we specify only_fields and is not
# in there. Or when we excldue this field in exclude_fields # in there. Or when we excldue this field in exclude_fields
continue continue
converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry) converted_relationship = convert_sqlalchemy_relationship(relationship, options.registry, options.connections, type_name)
name = relationship.key name = relationship.key
fields[name] = converted_relationship fields[name] = converted_relationship
@ -82,7 +83,8 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
exclude_fields=(), exclude_fields=(),
id='id', id='id',
interfaces=(), interfaces=(),
registry=None registry=None,
connections={},
) )
if not options.registry: if not options.registry:
@ -96,13 +98,12 @@ class SQLAlchemyObjectTypeMeta(ObjectTypeMeta):
'{}.Meta, received "{}".' '{}.Meta, received "{}".'
).format(name, options.model) ).format(name, options.model)
cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = ObjectTypeMeta.__new__(cls, name, bases, dict(attrs, _meta=options))
options.registry.register(cls) options.registry.register(cls)
options.sqlalchemy_fields = yank_fields_from_attrs( options.sqlalchemy_fields = yank_fields_from_attrs(
construct_fields(options), construct_fields(options, name),
_as=Field, _as=Field,
) )
options.fields = merge( options.fields = merge(

View File

@ -2,17 +2,18 @@ import re
from collections import Iterable, OrderedDict from collections import Iterable, OrderedDict
from functools import partial from functools import partial
from promise import Promise
import six import six
from graphql_relay import connection_from_list from graphql_relay import connection_from_list
from ..types import Boolean, Int, List, String, AbstractType from ..types import Boolean, Int, List, String, AbstractType, Dynamic
from ..types.field import Field from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options from ..types.options import Options
from ..utils.is_base_type import is_base_type from ..utils.is_base_type import is_base_type
from ..utils.props import props from ..utils.props import props
from .node import Node, is_node from .node import Node
class PageInfo(ObjectType): class PageInfo(ObjectType):
@ -55,47 +56,79 @@ class ConnectionMeta(ObjectTypeMeta):
) )
options.interfaces = () options.interfaces = ()
options.local_fields = OrderedDict() options.local_fields = OrderedDict()
base_name = re.sub('Connection$', '', name)
if attrs.get('edges'):
edges = attrs.get('edges')
edge = edges.of_type
else:
assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__) assert options.node, 'You have to provide a node in {}.Meta'.format(cls.__name__)
assert issubclass(options.node, (Node, ObjectType)), ( assert issubclass(options.node, (Node, ObjectType)), (
'Received incompatible node "{}" for Connection {}.' 'Received incompatible node "{}" for Connection {}.'
).format(options.node, name) ).format(options.node, name)
base_name = re.sub('Connection$', '', name) base_name = re.sub('Connection$', '', name)
if not options.name:
options.name = '{}Connection'.format(base_name)
edge_class = attrs.pop('Edge', None) edge_class = attrs.pop('Edge', None)
class EdgeBase(AbstractType): edge_attrs = {
node = Field(options.node, description='The item at the end of the edge') 'node': Field(
cursor = String(required=True, description='A cursor for use in pagination') options.node, description='The item at the end of the edge'),
'cursor': Edge._meta.fields['cursor']
}
edge_name = '{}Edge'.format(base_name) edge_name = '{}Edge'.format(base_name)
if edge_class and issubclass(edge_class, AbstractType): if edge_class and issubclass(edge_class, AbstractType):
edge = type(edge_name, (EdgeBase, edge_class, ObjectType, ), {}) edge = type(edge_name, (edge_class, ObjectType, ), edge_attrs)
else: else:
edge_attrs = props(edge_class) if edge_class else {} additional_attrs = props(edge_class) if edge_class else {}
edge = type(edge_name, (EdgeBase, ObjectType, ), edge_attrs) edge_attrs.update(additional_attrs)
edge = type(edge_name, (ObjectType, ), edge_attrs)
class ConnectionBase(AbstractType):
page_info = Field(PageInfo, name='pageInfo', required=True)
edges = List(edge) edges = List(edge)
bases = (ConnectionBase, ) + bases if not options.name:
options.name = '{}Connection'.format(base_name)
attrs.update({
'page_info': Field(PageInfo, name='pageInfo', required=True),
'edges': edges,
})
attrs = dict(attrs, _meta=options, Edge=edge) attrs = dict(attrs, _meta=options, Edge=edge)
return ObjectTypeMeta.__new__(cls, name, bases, attrs) return ObjectTypeMeta.__new__(cls, name, bases, attrs)
class Connection(six.with_metaclass(ConnectionMeta, ObjectType)): class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
pass
@classmethod
def for_type(cls, gql_type):
connection_name = '{}Connection'.format(gql_type._meta.name)
class Meta(object):
node = gql_type
return type(connection_name, (Connection, ), {'Meta': Meta})
class Edge(AbstractType):
cursor = String(required=True, description='A cursor for use in pagination')
def is_connection(gql_type):
'''Checks if a type is a connection. Taken directly from the spec definition:
https://facebook.github.io/relay/graphql/connections.htm#sec-Connection-Types'''
return gql_type._meta.name.endswith('Connection') if hasattr(gql_type, '_meta') else False
class IterableConnectionField(Field): class IterableConnectionField(Field):
def __init__(self, type, *args, **kwargs): def __init__(self, gql_type, *args, **kwargs):
assert is_connection(gql_type) or isinstance(gql_type, Dynamic), (
'The provided type "{}" for this ConnectionField has to be a Connection as defined by the Relay'
' spec.'.format(gql_type)
)
super(IterableConnectionField, self).__init__( super(IterableConnectionField, self).__init__(
type, gql_type,
*args, *args,
before=String(), before=String(),
after=String(), after=String(),
@ -106,32 +139,39 @@ class IterableConnectionField(Field):
@property @property
def type(self): def type(self):
type = super(IterableConnectionField, self).type gql_type = super(IterableConnectionField, self).type
if is_node(type): if isinstance(gql_type, Dynamic):
connection_type = type.Connection return gql_type.get_type()
else: else:
connection_type = type return gql_type
assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection. Received "{}".'
).format(str(self), connection_type)
return connection_type
@staticmethod @staticmethod
def connection_resolver(resolver, connection, root, args, context, info): def connection_resolver(resolver, connection, root, args, context, info):
iterable = resolver(root, args, context, info) resolved = Promise.resolve(resolver(root, args, context, info))
assert isinstance(iterable, Iterable), (
def handle_connection_and_list(result):
if isinstance(result, connection):
return result
elif is_connection(result):
raise AssertionError('Resolved value from the connection field has to be a {}. '
'Received {}.'.format(connection, type(result)))
else:
assert isinstance(result, Iterable), (
'Resolved value from the connection field have to be iterable. ' 'Resolved value from the connection field have to be iterable. '
'Received "{}"' 'Received "{}"'
).format(iterable) ).format(result)
connection = connection_from_list(
iterable, resolved_connection = connection_from_list(
result,
args, args,
connection_type=connection, connection_type=connection,
edge_type=connection.Edge, edge_type=connection.Edge,
pageinfo_type=PageInfo pageinfo_type=PageInfo
) )
connection.iterable = iterable resolved_connection.iterable = result
return connection return resolved_connection
return resolved.then(handle_connection_and_list)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
resolver = super(IterableConnectionField, self).get_resolver(parent_resolver) resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)

View File

@ -21,18 +21,6 @@ def is_node(objecttype):
return False return False
def get_default_connection(cls):
from .connection import Connection
assert issubclass(cls, ObjectType), (
'Can only get connection type on implemented Nodes.'
)
class Meta:
node = cls
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
class GlobalID(Field): class GlobalID(Field):
def __init__(self, node, *args, **kwargs): def __init__(self, node, *args, **kwargs):
@ -100,11 +88,3 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
@classmethod @classmethod
def to_global_id(cls, type, id): def to_global_id(cls, type, id):
return to_global_id(type, id) return to_global_id(type, id)
@classmethod
def implements(cls, objecttype):
get_connection = getattr(objecttype, 'get_connection', None)
if not get_connection:
get_connection = partial(get_default_connection, objecttype)
objecttype.Connection = get_connection()

View File

@ -1,6 +1,5 @@
from ...types import Field, List, NonNull, ObjectType, String, AbstractType from ...types import Field, List, NonNull, ObjectType, String, AbstractType
from ..connection import Connection, PageInfo from ..connection import Connection, PageInfo, Edge
from ..node import Node from ..node import Node
@ -23,7 +22,52 @@ def test_connection():
assert MyObjectConnection._meta.name == 'MyObjectConnection' assert MyObjectConnection._meta.name == 'MyObjectConnection'
fields = MyObjectConnection._meta.fields fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ['page_info', 'edges', 'extra'] assert list(fields.keys()) == ['extra', 'edges', 'page_info']
edge_field = fields['edges']
pageinfo_field = fields['page_info']
assert isinstance(edge_field, Field)
assert isinstance(edge_field.type, List)
assert edge_field.type.of_type == MyObjectConnection.Edge
assert isinstance(pageinfo_field, Field)
assert isinstance(pageinfo_field.type, NonNull)
assert pageinfo_field.type.of_type == PageInfo
def test_multiple_connection_edges_are_not_the_same():
class MyObjectConnection(Connection):
extra = String()
class Meta:
node = MyObject
class Edge:
other = String()
class MyOtherObjectConnection(Connection):
class Meta:
node = MyObject
class Edge:
other = String()
assert MyObjectConnection.Edge != MyOtherObjectConnection.Edge
assert MyObjectConnection.Edge._meta.name != MyOtherObjectConnection.Edge._meta.name
def test_create_connection_with_custom_edge_type():
class MyEdge(Edge):
node = Field(MyObject)
class MyObjectConnection(Connection):
extra = String()
edges = List(MyEdge)
assert MyObjectConnection.Edge == MyEdge
assert MyObjectConnection._meta.name == 'MyObjectConnection'
fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ['extra', 'edges', 'page_info']
edge_field = fields['edges'] edge_field = fields['edges']
pageinfo_field = fields['page_info'] pageinfo_field = fields['page_info']
@ -46,7 +90,18 @@ def test_connection_inherit_abstracttype():
assert MyObjectConnection._meta.name == 'MyObjectConnection' assert MyObjectConnection._meta.name == 'MyObjectConnection'
fields = MyObjectConnection._meta.fields fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ['page_info', 'edges', 'extra'] assert list(fields.keys()) == ['extra', 'edges', 'page_info']
def test_defaul_connection_for_type():
MyObjectConnection = Connection.for_type(MyObject)
assert MyObjectConnection._meta.name == 'MyObjectConnection'
fields = MyObjectConnection._meta.fields
assert list(fields.keys()) == ['edges', 'page_info']
def test_default_connection_for_type_does_not_returns_same_Connection():
assert Connection.for_type(MyObject) != Connection.for_type(MyObject)
def test_edge(): def test_edge():
@ -60,7 +115,7 @@ def test_edge():
Edge = MyObjectConnection.Edge Edge = MyObjectConnection.Edge
assert Edge._meta.name == 'MyObjectEdge' assert Edge._meta.name == 'MyObjectEdge'
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ['node', 'cursor', 'other'] assert list(edge_fields.keys()) == ['cursor', 'other', 'node']
assert isinstance(edge_fields['node'], Field) assert isinstance(edge_fields['node'], Field)
assert edge_fields['node'].type == MyObject assert edge_fields['node'].type == MyObject
@ -83,7 +138,7 @@ def test_edge_with_bases():
Edge = MyObjectConnection.Edge Edge = MyObjectConnection.Edge
assert Edge._meta.name == 'MyObjectEdge' assert Edge._meta.name == 'MyObjectEdge'
edge_fields = Edge._meta.fields edge_fields = Edge._meta.fields
assert list(edge_fields.keys()) == ['node', 'cursor', 'extra', 'other'] assert list(edge_fields.keys()) == ['extra', 'other', 'cursor', 'node']
assert isinstance(edge_fields['node'], Field) assert isinstance(edge_fields['node'], Field)
assert edge_fields['node'].type == MyObject assert edge_fields['node'].type == MyObject
@ -92,17 +147,37 @@ def test_edge_with_bases():
assert edge_fields['other'].type == String assert edge_fields['other'].type == String
def test_edge_on_node(): def test_pageinfo():
Edge = MyObject.Connection.Edge assert PageInfo._meta.name == 'PageInfo'
assert Edge._meta.name == 'MyObjectEdge' fields = PageInfo._meta.fields
edge_fields = Edge._meta.fields assert list(fields.keys()) == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor']
assert list(edge_fields.keys()) == ['node', 'cursor']
def test_edge_for_node_type():
edge = Connection.for_type(MyObject).Edge
assert edge._meta.name == 'MyObjectEdge'
edge_fields = edge._meta.fields
assert list(edge_fields.keys()) == ['cursor', 'node']
assert isinstance(edge_fields['node'], Field) assert isinstance(edge_fields['node'], Field)
assert edge_fields['node'].type == MyObject assert edge_fields['node'].type == MyObject
def test_pageinfo(): def test_edge_for_object_type():
assert PageInfo._meta.name == 'PageInfo' class MyObject(ObjectType):
fields = PageInfo._meta.fields field = String()
assert list(fields.keys()) == ['has_next_page', 'has_previous_page', 'start_cursor', 'end_cursor']
edge = Connection.for_type(MyObject).Edge
assert edge._meta.name == 'MyObjectEdge'
edge_fields = edge._meta.fields
assert list(edge_fields.keys()) == ['cursor', 'node']
assert isinstance(edge_fields['node'], Field)
assert edge_fields['node'].type == MyObject
def test_edge_for_type_returns_same_edge():
MyObjectConnection = Connection.for_type(MyObject)
assert MyObjectConnection.Edge == MyObjectConnection.Edge

View File

@ -1,6 +1,8 @@
from collections import OrderedDict from collections import OrderedDict
from ..connection import ConnectionField from promise import Promise
from ..connection import ConnectionField, Connection, PageInfo
from ..node import Node from ..node import Node
from graphql_relay.utils import base64 from graphql_relay.utils import base64
from ...types import ObjectType, String, Schema from ...types import ObjectType, String, Schema
@ -15,12 +17,40 @@ class Letter(ObjectType):
letter = String() letter = String()
class Query(ObjectType): class MyLetterObjectConnection(Connection):
letters = ConnectionField(Letter) extra = String()
def resolve_letters(self, args, context, info): class Meta:
node = Letter
class Edge:
other = String()
LetterConnection = Connection.for_type(Letter)
class Query(ObjectType):
letters = ConnectionField(LetterConnection)
letters_wrong_connection = ConnectionField(LetterConnection)
letters_promise = ConnectionField(LetterConnection)
letters_connection = ConnectionField(MyLetterObjectConnection)
def resolve_letters(self, *_):
return list(letters.values()) return list(letters.values())
def resolve_letters_wrong_connection(self, *_):
return MyLetterObjectConnection()
def resolve_letters_connection(self, *_):
return MyLetterObjectConnection(
extra='1',
page_info=PageInfo(has_next_page=True, has_previous_page=False),
edges=[MyLetterObjectConnection.Edge(cursor='1', node=Letter(letter='hello'))]
)
def resolve_letters_promise(self, *_):
return Promise.resolve(list(letters.values()))
node = Node.Field() node = Node.Field()
@ -75,8 +105,7 @@ def execute(args=''):
''' % args) ''' % args)
def check(args, letters, has_previous_page=False, has_next_page=False): def create_expexted_result(letters, has_previous_page=False, has_next_page=False, field_name='letters'):
result = execute(args)
expected_edges = edges(letters) expected_edges = edges(letters)
expected_page_info = { expected_page_info = {
'hasPreviousPage': has_previous_page, 'hasPreviousPage': has_previous_page,
@ -84,16 +113,107 @@ def check(args, letters, has_previous_page=False, has_next_page=False):
'endCursor': expected_edges[-1]['cursor'] if expected_edges else None, 'endCursor': expected_edges[-1]['cursor'] if expected_edges else None,
'startCursor': expected_edges[0]['cursor'] if expected_edges else None 'startCursor': expected_edges[0]['cursor'] if expected_edges else None
} }
return {
assert not result.errors field_name: {
assert result.data == {
'letters': {
'edges': expected_edges, 'edges': expected_edges,
'pageInfo': expected_page_info 'pageInfo': expected_page_info
} }
} }
def check(args, letters, has_previous_page=False, has_next_page=False):
result = execute(args)
assert not result.errors
assert result.data == create_expexted_result(letters, has_previous_page, has_next_page)
def test_resolver_throws_error_on_returning_wrong_connection_type():
result = schema.execute('''
{
lettersWrongConnection {
edges {
node {
id
}
}
}
}
''')
assert result.errors[0].message == ('Resolved value from the connection field has to be a LetterConnection. '
'Received MyLetterObjectConnection.')
def test_resolver_handles_returned_connection_field_correctly():
result = schema.execute('''
{
lettersConnection {
extra
edges {
node {
id
letter
}
cursor
}
pageInfo {
hasPreviousPage
hasNextPage
startCursor
endCursor
}
}
}
''')
assert not result.errors
expected_result = {
'lettersConnection': {
'extra': '1',
'edges': [
{
'node': {
'id': 'TGV0dGVyOk5vbmU=',
'letter': 'hello',
},
'cursor': '1'
}
],
'pageInfo': {
'hasPreviousPage': False,
'hasNextPage': True,
'startCursor': None,
'endCursor': None,
}
}
}
assert result.data == expected_result
def test_resolver_handles_returned_promise_correctly():
result = schema.execute('''
{
lettersPromise {
edges {
node {
id
letter
}
cursor
}
pageInfo {
hasPreviousPage
hasNextPage
startCursor
endCursor
}
}
}
''')
assert not result.errors
assert result.data == create_expexted_result('ABCDE', field_name='lettersPromise')
def test_returns_all_elements_without_filters(): def test_returns_all_elements_without_filters():
check('', 'ABCDE') check('', 'ABCDE')

View File

@ -1,4 +1,3 @@
from collections import OrderedDict
import pytest import pytest
from ...types import (Argument, Field, InputField, InputObjectType, ObjectType, from ...types import (Argument, Field, InputField, InputObjectType, ObjectType,
@ -21,6 +20,9 @@ class MyNode(ObjectType):
name = String() name = String()
MyNodeConnection = Connection.for_type(MyNode)
class SaySomething(ClientIDMutation): class SaySomething(ClientIDMutation):
class Input: class Input:
@ -39,13 +41,13 @@ class OtherMutation(ClientIDMutation):
additional_field = String() additional_field = String()
name = String() name = String()
my_node_edge = Field(MyNode.Connection.Edge) my_node_edge = Field(MyNodeConnection.Edge)
@classmethod @classmethod
def mutate_and_get_payload(cls, args, context, info): def mutate_and_get_payload(cls, args, context, info):
shared = args.get('shared', '') shared = args.get('shared', '')
additionalField = args.get('additionalField', '') additionalField = args.get('additionalField', '')
edge_type = MyNode.Connection.Edge edge_type = MyNodeConnection.Edge
return OtherMutation(name=shared + additionalField, return OtherMutation(name=shared + additionalField,
my_node_edge=edge_type( my_node_edge=edge_type(
cursor='1', node=MyNode(name='name'))) cursor='1', node=MyNode(name='name')))

View File

@ -53,15 +53,6 @@ def test_node_good():
assert 'id' in MyNode._meta.fields assert 'id' in MyNode._meta.fields
def test_node_get_connection():
connection = MyNode.Connection
assert issubclass(connection, Connection)
def test_node_get_connection_dont_duplicate():
assert MyNode.Connection == MyNode.Connection
def test_node_query(): def test_node_query():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1) '{ node(id:"%s") { ... on MyNode { name } } }' % to_global_id("MyNode", 1)

View File

@ -16,7 +16,7 @@ def source_resolver(source, root, args, context, info):
class Field(OrderedType): class Field(OrderedType):
def __init__(self, type, args=None, resolver=None, source=None, def __init__(self, gql_type, args=None, resolver=None, source=None,
deprecation_reason=None, name=None, description=None, deprecation_reason=None, name=None, description=None,
required=False, _creation_counter=None, **extra_args): required=False, _creation_counter=None, **extra_args):
super(Field, self).__init__(_creation_counter=_creation_counter) super(Field, self).__init__(_creation_counter=_creation_counter)
@ -28,10 +28,10 @@ class Field(OrderedType):
) )
if required: if required:
type = NonNull(type) gql_type = NonNull(gql_type)
self.name = name self.name = name
self._type = type self._type = gql_type
self.args = to_arguments(args or OrderedDict(), extra_args) self.args = to_arguments(args or OrderedDict(), extra_args)
if source: if source:
resolver = partial(source_resolver, source) resolver = partial(source_resolver, source)

View File

@ -52,7 +52,3 @@ class Interface(six.with_metaclass(InterfaceMeta)):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
raise Exception("An Interface cannot be intitialized") raise Exception("An Interface cannot be intitialized")
@classmethod
def implements(cls, objecttype):
pass

View File

@ -46,9 +46,6 @@ class ObjectTypeMeta(AbstractTypeMeta):
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
for interface in options.interfaces:
interface.implements(cls)
return cls return cls
def __str__(cls): # noqa: N802 def __str__(cls): # noqa: N802

View File

@ -16,7 +16,7 @@ else:
# machinery. # machinery.
builtins.__SETUP__ = True builtins.__SETUP__ = True
version = __import__('graphene').get_version() version = "1.0.beta-1"
class PyTest(TestCommand): class PyTest(TestCommand):

View File

@ -5,8 +5,8 @@ skipsdist = true
[testenv] [testenv]
deps= deps=
pytest>=2.7.2 pytest>=2.7.2
graphql-core>=0.5.1 graphql-core>=1.0.dev
graphql-relay>=0.4.3 graphql-relay>=0.4.4
six six
blinker blinker
singledispatch singledispatch