Added ClientIDMutation. All examples working 💪

This commit is contained in:
Syrus Akbary 2016-08-13 23:00:25 -07:00
parent e2036da75f
commit 7804f10732
12 changed files with 311 additions and 64 deletions

View File

@ -21,11 +21,15 @@ class Character(graphene.Interface):
return [get_character(f) for f in self.friends] return [get_character(f) for f in self.friends]
class Human(graphene.ObjectType, Character): class Human(graphene.ObjectType):
class Meta:
interfaces = (Character, )
home_planet = graphene.String() home_planet = graphene.String()
class Droid(graphene.ObjectType, Character): class Droid(graphene.ObjectType):
class Meta:
interfaces = (Character, )
primary_function = graphene.String() primary_function = graphene.String()

View File

@ -5,9 +5,11 @@ from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
class Ship(graphene.ObjectType): class Ship(graphene.ObjectType):
class Meta:
interfaces = [relay.Node]
'''A ship in the Star Wars saga''' '''A ship in the Star Wars saga'''
class Meta:
interfaces = (relay.Node, )
name = graphene.String(description='The name of the ship.') name = graphene.String(description='The name of the ship.')
@classmethod @classmethod
@ -15,8 +17,12 @@ class Ship(graphene.ObjectType):
return get_ship(id) return get_ship(id)
class Faction(relay.Node, graphene.ObjectType): class Faction(graphene.ObjectType):
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
class Meta:
interfaces = (relay.Node, )
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField(Ship, description='The ships used by the faction.') ships = relay.ConnectionField(Ship, description='The ships used by the faction.')

View File

@ -4,6 +4,67 @@ from ..schema import schema
setup() setup()
def test_str_schema():
assert str(schema) == '''schema {
query: Query
mutation: Mutation
}
type Faction implements Node {
id: ID!
name: String
ships(before: String, after: String, first: Int, last: Int): ShipConnection
}
input IntroduceShipInput {
shipName: String!
factionId: String!
clientMutationId: String
}
type IntroduceShipPayload {
ship: Ship
faction: Faction
}
type Mutation {
introduceShip(input: IntroduceShipInput): IntroduceShipPayload
}
interface Node {
id: ID!
}
type PageInfo {
hasNextPage: Boolean!
hasPreviousPage: Boolean!
startCursor: String
endCursor: String
}
type Query {
rebels: Faction
empire: Faction
node(id: ID!): Node
}
type Ship implements Node {
id: ID!
name: String
}
type ShipConnection {
pageInfo: PageInfo!
edges: [ShipEdge]
}
type ShipEdge {
node: Ship
cursor: String!
}
'''
def test_correctly_fetches_id_name_rebels(): def test_correctly_fetches_id_name_rebels():
query = ''' query = '''
query RebelsQuery { query RebelsQuery {

View File

@ -1,10 +1,10 @@
from .node import Node from .node import Node
# from .mutation import ClientIDMutation from .mutation import ClientIDMutation
from .connection import Connection, ConnectionField from .connection import Connection, ConnectionField
__all__ = [ __all__ = [
'Node', 'Node',
# 'ClientIDMutation', 'ClientIDMutation',
'Connection', 'Connection',
'ConnectionField', 'ConnectionField',
] ]

View File

@ -5,9 +5,8 @@ from functools import partial
import six import six
from graphql_relay import connection_from_list from graphql_relay import connection_from_list
from graphql_relay.connection.connection import connection_args
from ..types import Boolean, String, List from ..types import Boolean, String, List, Int
from ..types.field import Field from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options from ..types.options import Options
@ -15,7 +14,7 @@ from ..utils.is_base_type import is_base_type
from ..utils.props import props from ..utils.props import props
from .node import Node from .node import Node
from ..types.utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs from ..types.utils import get_fields_in_type, yank_fields_from_attrs
def is_node(objecttype): def is_node(objecttype):
@ -81,9 +80,14 @@ class ConnectionMeta(ObjectTypeMeta):
('cursor', Field(String, required=True, description='A cursor for use in pagination')) ('cursor', Field(String, required=True, description='A cursor for use in pagination'))
]) ])
edge_attrs = props(edge_class) if edge_class else OrderedDict() edge_attrs = props(edge_class) if edge_class else OrderedDict()
edge_fields.update(get_fields_in_type(ObjectType, edge_attrs)) extended_edge_fields = get_fields_in_type(ObjectType, edge_attrs)
EdgeMeta = type('Meta', (object, ), {'fields': edge_fields}) edge_fields.update(extended_edge_fields)
Edge = type('{}Edge'.format(base_name), (ObjectType,), {'Meta': EdgeMeta}) EdgeMeta = type('Meta', (object, ), {
'fields': edge_fields,
'name': '{}Edge'.format(base_name)
})
yank_fields_from_attrs(edge_attrs, extended_edge_fields)
Edge = type('Edge', (ObjectType,), dict(edge_attrs, Meta=EdgeMeta))
options.local_fields = OrderedDict([ options.local_fields = OrderedDict([
('page_info', Field(PageInfo, name='pageInfo', required=True)), ('page_info', Field(PageInfo, name='pageInfo', required=True)),
@ -101,9 +105,10 @@ class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
resolve_node = None resolve_node = None
resolve_cursor = None resolve_cursor = None
def __init__(self, *args, **kwargs): # def __init__(self, *args, **kwargs):
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info')) # super(Connection, self).__init__(*args, **kwargs)
super(Connection, self).__init__(*args, **kwargs) # print args, kwargs
# print dir(self.page_info)
class IterableConnectionField(Field): class IterableConnectionField(Field):
@ -118,17 +123,21 @@ class IterableConnectionField(Field):
super(IterableConnectionField, self).__init__( super(IterableConnectionField, self).__init__(
type, type,
args=connection_args,
*args, *args,
before=String(),
after=String(),
first=Int(),
last=Int(),
**kwargs **kwargs
) )
@property @property
def connection(self): def type(self):
if is_node(self.type): type = super(IterableConnectionField, self).type
connection_type = self.type.Connection if is_node(type):
connection_type = type.Connection
else: else:
connection_type = self.type connection_type = type
assert issubclass(connection_type, Connection), ( assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection' '{} type have to be a subclass of Connection'
).format(str(self)) ).format(str(self))
@ -146,10 +155,11 @@ class IterableConnectionField(Field):
args, args,
connection_type=connection, connection_type=connection,
edge_type=connection.Edge, edge_type=connection.Edge,
pageinfo_type=PageInfo
) )
return connection return connection
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.connection) return partial(self.connection_resolver, parent_resolver, self.type)
ConnectionField = IterableConnectionField ConnectionField = IterableConnectionField

View File

@ -0,0 +1,54 @@
from functools import partial
import six
import re
from promise import Promise
from ..utils.is_base_type import is_base_type
from ..utils.props import props
from ..types import Field, String, InputObjectType, Argument
from ..types.objecttype import ObjectType, ObjectTypeMeta
class ClientIDMutationMeta(ObjectTypeMeta):
def __new__(cls, name, bases, attrs):
# Also ensure initialization is only performed for subclasses of
# Mutation
if not is_base_type(bases, ClientIDMutationMeta):
return type.__new__(cls, name, bases, attrs)
input_class = attrs.pop('Input', None)
base_name = re.sub('Payload$', '', name)
cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs)
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__:
assert mutate_and_get_payload, (
"{}.mutate_and_get_payload method is required"
" in a ClientIDMutation."
).format(name)
input_attrs = props(input_class) if input_class else {}
input_attrs['client_mutation_id'] = String(name='clientMutationId')
cls.Input = type('{}Input'.format(base_name), (InputObjectType,), input_attrs)
cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input))
return cls
class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, ObjectType)):
@classmethod
def mutate(cls, root, args, context, info):
input = args.get('input')
def on_resolve(payload):
try:
payload.clientMutationId = input['clientMutationId']
except:
raise Exception((
'Cannot set clientMutationId in the payload object {}'
).format(repr(payload)))
return payload
return Promise.resolve(
cls.mutate_and_get_payload(input, context, info)
).then(on_resolve)

View File

@ -1,17 +1,55 @@
import six
from collections import OrderedDict
from functools import partial
from graphql_relay import from_global_id, to_global_id from graphql_relay import from_global_id, to_global_id
from ..types import Interface, ID, Field from ..types import ObjectType, Interface, ID, Field
from ..types.interface import InterfaceMeta
class Node(Interface): def get_default_connection(cls):
from .connection import Connection
assert issubclass(cls, ObjectType), (
'Can only get connection type on implemented Nodes.'
)
class Meta:
node = cls
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
class GlobalID(Field):
def __init__(self, node, *args, **kwargs):
super(GlobalID, self).__init__(ID, *args, **kwargs)
self.node = node
@staticmethod
def id_resolver(parent_resolver, node, root, args, context, info):
id = parent_resolver(root, args, context, info)
# type_name = root._meta.name # info.parent_type.name
return node.to_global_id(info.parent_type.name, id)
def get_resolver(self, parent_resolver):
return partial(self.id_resolver, parent_resolver, self.node)
class NodeMeta(InterfaceMeta):
def __new__(cls, name, bases, attrs):
cls = InterfaceMeta.__new__(cls, name, bases, attrs)
cls._meta.fields['id'] = GlobalID(cls, required=True, description='The ID of the object.')
# new_fields = OrderedDict([
# ('id', GlobalID(cls, required=True, description='The ID of the object.'))
# ])
# new_fields.update(cls._meta.fields)
# cls._meta.fields = new_fields
return cls
class Node(six.with_metaclass(NodeMeta, Interface)):
'''An object with an ID''' '''An object with an ID'''
id = ID(required=True, description='The ID of the object.')
@classmethod
def resolve_id(cls, root, args, context, info):
type_name = root._meta.name # info.parent_type.name
return cls.to_global_id(type_name, getattr(root, 'id', None))
@classmethod @classmethod
def Field(cls): def Field(cls):
def resolve_node(root, args, context, info): def resolve_node(root, args, context, info):
@ -45,13 +83,13 @@ class Node(Interface):
@classmethod @classmethod
def implements(cls, objecttype): def implements(cls, objecttype):
require_get_node = Node in objecttype._meta.interfaces require_get_node = cls.get_node_from_global_id == Node.get_node_from_global_id
# get_connection = getattr(objecttype, 'get_connection', None) get_connection = getattr(objecttype, 'get_connection', None)
# if not get_connection: if not get_connection:
# get_connection = partial(get_default_connection, objecttype) get_connection = partial(get_default_connection, objecttype)
# objecttype.Connection = get_connection() objecttype.Connection = get_connection()
if require_get_node: if require_get_node:
assert hasattr( assert hasattr(objecttype, 'get_node'), (
objecttype, 'get_node'), '{}.get_node method is required by the Node interface.'.format( '{}.get_node method is required by the Node interface.'
objecttype.__name__) ).format(objecttype.__name__)

View File

@ -0,0 +1,66 @@
import pytest
from ...types import ObjectType, Schema, Field, InputField, InputObjectType, Argument
from ...types.scalars import String
from ..mutation import ClientIDMutation
class SaySomething(ClientIDMutation):
class Input:
what = String()
phrase = String()
@staticmethod
def mutate_and_get_payload(args, context, info):
what = args.get('what')
return SaySomething(phrase=str(what))
class RootQuery(ObjectType):
something = String()
class Mutation(ObjectType):
say = SaySomething.Field()
schema = Schema(query=RootQuery, mutation=Mutation)
def test_no_mutate_and_get_payload():
with pytest.raises(AssertionError) as excinfo:
class MyMutation(ClientIDMutation):
pass
assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation." == str(
excinfo.value)
def test_mutation():
fields = SaySomething._meta.fields
assert fields.keys() == ['phrase']
assert isinstance(fields['phrase'], Field)
field = SaySomething.Field()
assert field.type == SaySomething
assert field.args.keys() == ['input']
assert isinstance(field.args['input'], Argument)
assert field.args['input'].type == SaySomething.Input
def test_mutation_input():
Input = SaySomething.Input
assert issubclass(Input, InputObjectType)
fields = Input._meta.fields
assert fields.keys() == ['what', 'client_mutation_id']
assert isinstance(fields['what'], InputField)
assert fields['what'].type == String
assert isinstance(fields['client_mutation_id'], InputField)
assert fields['client_mutation_id'].type == String
# def test_node_query():
# executed = schema.execute(
# 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
# )
# assert not executed.errors
# assert executed.data == {'say': {'phrase': 'hello'}}

View File

@ -3,8 +3,7 @@ import pytest
from graphql_relay import to_global_id from graphql_relay import to_global_id
from ...types import ObjectType, Schema, String from ...types import ObjectType, Schema, String
# from ...types.scalars import String from ..connection import Connection
# from ..connection import Connection
from ..node import Node from ..node import Node
@ -47,13 +46,13 @@ def test_node_good():
assert 'id' in MyNode._meta.fields assert 'id' in MyNode._meta.fields
# def test_node_get_connection(): def test_node_get_connection():
# connection = MyNode.Connection connection = MyNode.Connection
# assert issubclass(connection, Connection) assert issubclass(connection, Connection)
# def test_node_get_connection_dont_duplicate(): def test_node_get_connection_dont_duplicate():
# assert MyNode.Connection == MyNode.Connection assert MyNode.Connection == MyNode.Connection
def test_node_query(): def test_node_query():

View File

@ -5,6 +5,7 @@ from .options import Options
from .abstracttype import AbstractTypeMeta from .abstracttype import AbstractTypeMeta
from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
from .unmountedtype import UnmountedType
class InputObjectTypeMeta(AbstractTypeMeta): class InputObjectTypeMeta(AbstractTypeMeta):
@ -19,11 +20,13 @@ class InputObjectTypeMeta(AbstractTypeMeta):
attrs.pop('Meta', None), attrs.pop('Meta', None),
name=name, name=name,
description=attrs.get('__doc__'), description=attrs.get('__doc__'),
fields=None,
) )
attrs = merge_fields_in_attrs(bases, attrs) attrs = merge_fields_in_attrs(bases, attrs)
options.fields = get_fields_in_type(InputObjectType, attrs) if not options.fields:
yank_fields_from_attrs(attrs, options.fields) options.fields = get_fields_in_type(InputObjectType, attrs)
yank_fields_from_attrs(attrs, options.fields)
return type.__new__(cls, name, bases, dict(attrs, _meta=options)) return type.__new__(cls, name, bases, dict(attrs, _meta=options))
@ -31,6 +34,9 @@ class InputObjectTypeMeta(AbstractTypeMeta):
return cls._meta.name return cls._meta.name
class InputObjectType(six.with_metaclass(InputObjectTypeMeta)): class InputObjectType(six.with_metaclass(InputObjectTypeMeta, UnmountedType)):
def __init__(self, *args, **kwargs): @classmethod
raise Exception("An InputObjectType cannot be intitialized") def get_type(cls):
return cls
# def __init__(self, *args, **kwargs):
# raise Exception("An InputObjectType cannot be intitialized")

View File

@ -19,11 +19,13 @@ class InterfaceMeta(AbstractTypeMeta):
attrs.pop('Meta', None), attrs.pop('Meta', None),
name=name, name=name,
description=attrs.get('__doc__'), description=attrs.get('__doc__'),
fields=None,
) )
attrs = merge_fields_in_attrs(bases, attrs) attrs = merge_fields_in_attrs(bases, attrs)
options.fields = get_fields_in_type(Interface, attrs) if not options.fields:
yank_fields_from_attrs(attrs, options.fields) options.fields = get_fields_in_type(Interface, attrs)
yank_fields_from_attrs(attrs, options.fields)
return type.__new__(cls, name, bases, dict(attrs, _meta=options)) return type.__new__(cls, name, bases, dict(attrs, _meta=options))

View File

@ -22,20 +22,21 @@ class ObjectTypeMeta(AbstractTypeMeta):
name=name, name=name,
description=attrs.get('__doc__'), description=attrs.get('__doc__'),
interfaces=(), interfaces=(),
fields=OrderedDict(), fields=None,
) )
attrs = merge_fields_in_attrs(bases, attrs) attrs = merge_fields_in_attrs(bases, attrs)
options.local_fields = get_fields_in_type(ObjectType, attrs) if not options.fields:
yank_fields_from_attrs(attrs, options.local_fields) options.local_fields = get_fields_in_type(ObjectType, attrs)
options.interface_fields = OrderedDict() yank_fields_from_attrs(attrs, options.local_fields)
for interface in options.interfaces: options.interface_fields = OrderedDict()
assert issubclass(interface, Interface), ( for interface in options.interfaces:
'All interfaces of {} must be a subclass of Interface. Received "{}".' assert issubclass(interface, Interface), (
).format(name, interface) 'All interfaces of {} must be a subclass of Interface. Received "{}".'
options.interface_fields.update(interface._meta.fields) ).format(name, interface)
options.fields.update(options.interface_fields) options.interface_fields.update(interface._meta.fields)
options.fields.update(options.local_fields) options.fields = OrderedDict(options.interface_fields)
options.fields.update(options.local_fields)
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options)) cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
for interface in options.interfaces: for interface in options.interfaces: