Added ClientIDMutation. All examples working 💪

This commit is contained in:
Syrus Akbary 2016-08-13 23:00:25 -07:00
parent e2036da75f
commit 7804f10732
12 changed files with 311 additions and 64 deletions

View File

@ -21,11 +21,15 @@ class Character(graphene.Interface):
return [get_character(f) for f in self.friends]
class Human(graphene.ObjectType, Character):
class Human(graphene.ObjectType):
class Meta:
interfaces = (Character, )
home_planet = graphene.String()
class Droid(graphene.ObjectType, Character):
class Droid(graphene.ObjectType):
class Meta:
interfaces = (Character, )
primary_function = graphene.String()

View File

@ -5,9 +5,11 @@ from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
class Ship(graphene.ObjectType):
class Meta:
interfaces = [relay.Node]
'''A ship in the Star Wars saga'''
class Meta:
interfaces = (relay.Node, )
name = graphene.String(description='The name of the ship.')
@classmethod
@ -15,8 +17,12 @@ class Ship(graphene.ObjectType):
return get_ship(id)
class Faction(relay.Node, graphene.ObjectType):
class Faction(graphene.ObjectType):
'''A faction in the Star Wars saga'''
class Meta:
interfaces = (relay.Node, )
name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField(Ship, description='The ships used by the faction.')

View File

@ -4,6 +4,67 @@ from ..schema import schema
setup()
def test_str_schema():
assert str(schema) == '''schema {
query: Query
mutation: Mutation
}
type Faction implements Node {
id: ID!
name: String
ships(before: String, after: String, first: Int, last: Int): ShipConnection
}
input IntroduceShipInput {
shipName: String!
factionId: String!
clientMutationId: String
}
type IntroduceShipPayload {
ship: Ship
faction: Faction
}
type Mutation {
introduceShip(input: IntroduceShipInput): IntroduceShipPayload
}
interface Node {
id: ID!
}
type PageInfo {
hasNextPage: Boolean!
hasPreviousPage: Boolean!
startCursor: String
endCursor: String
}
type Query {
rebels: Faction
empire: Faction
node(id: ID!): Node
}
type Ship implements Node {
id: ID!
name: String
}
type ShipConnection {
pageInfo: PageInfo!
edges: [ShipEdge]
}
type ShipEdge {
node: Ship
cursor: String!
}
'''
def test_correctly_fetches_id_name_rebels():
query = '''
query RebelsQuery {

View File

@ -1,10 +1,10 @@
from .node import Node
# from .mutation import ClientIDMutation
from .mutation import ClientIDMutation
from .connection import Connection, ConnectionField
__all__ = [
'Node',
# 'ClientIDMutation',
'ClientIDMutation',
'Connection',
'ConnectionField',
]

View File

@ -5,9 +5,8 @@ from functools import partial
import six
from graphql_relay import connection_from_list
from graphql_relay.connection.connection import connection_args
from ..types import Boolean, String, List
from ..types import Boolean, String, List, Int
from ..types.field import Field
from ..types.objecttype import ObjectType, ObjectTypeMeta
from ..types.options import Options
@ -15,7 +14,7 @@ from ..utils.is_base_type import is_base_type
from ..utils.props import props
from .node import Node
from ..types.utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
from ..types.utils import get_fields_in_type, yank_fields_from_attrs
def is_node(objecttype):
@ -81,9 +80,14 @@ class ConnectionMeta(ObjectTypeMeta):
('cursor', Field(String, required=True, description='A cursor for use in pagination'))
])
edge_attrs = props(edge_class) if edge_class else OrderedDict()
edge_fields.update(get_fields_in_type(ObjectType, edge_attrs))
EdgeMeta = type('Meta', (object, ), {'fields': edge_fields})
Edge = type('{}Edge'.format(base_name), (ObjectType,), {'Meta': EdgeMeta})
extended_edge_fields = get_fields_in_type(ObjectType, edge_attrs)
edge_fields.update(extended_edge_fields)
EdgeMeta = type('Meta', (object, ), {
'fields': edge_fields,
'name': '{}Edge'.format(base_name)
})
yank_fields_from_attrs(edge_attrs, extended_edge_fields)
Edge = type('Edge', (ObjectType,), dict(edge_attrs, Meta=EdgeMeta))
options.local_fields = OrderedDict([
('page_info', Field(PageInfo, name='pageInfo', required=True)),
@ -101,9 +105,10 @@ class Connection(six.with_metaclass(ConnectionMeta, ObjectType)):
resolve_node = None
resolve_cursor = None
def __init__(self, *args, **kwargs):
kwargs['pageInfo'] = kwargs.pop('pageInfo', kwargs.pop('page_info'))
super(Connection, self).__init__(*args, **kwargs)
# def __init__(self, *args, **kwargs):
# super(Connection, self).__init__(*args, **kwargs)
# print args, kwargs
# print dir(self.page_info)
class IterableConnectionField(Field):
@ -118,17 +123,21 @@ class IterableConnectionField(Field):
super(IterableConnectionField, self).__init__(
type,
args=connection_args,
*args,
before=String(),
after=String(),
first=Int(),
last=Int(),
**kwargs
)
@property
def connection(self):
if is_node(self.type):
connection_type = self.type.Connection
def type(self):
type = super(IterableConnectionField, self).type
if is_node(type):
connection_type = type.Connection
else:
connection_type = self.type
connection_type = type
assert issubclass(connection_type, Connection), (
'{} type have to be a subclass of Connection'
).format(str(self))
@ -146,10 +155,11 @@ class IterableConnectionField(Field):
args,
connection_type=connection,
edge_type=connection.Edge,
pageinfo_type=PageInfo
)
return connection
def get_resolver(self, parent_resolver):
return partial(self.connection_resolver, parent_resolver, self.connection)
return partial(self.connection_resolver, parent_resolver, self.type)
ConnectionField = IterableConnectionField

View File

@ -0,0 +1,54 @@
from functools import partial
import six
import re
from promise import Promise
from ..utils.is_base_type import is_base_type
from ..utils.props import props
from ..types import Field, String, InputObjectType, Argument
from ..types.objecttype import ObjectType, ObjectTypeMeta
class ClientIDMutationMeta(ObjectTypeMeta):
def __new__(cls, name, bases, attrs):
# Also ensure initialization is only performed for subclasses of
# Mutation
if not is_base_type(bases, ClientIDMutationMeta):
return type.__new__(cls, name, bases, attrs)
input_class = attrs.pop('Input', None)
base_name = re.sub('Payload$', '', name)
cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs)
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__:
assert mutate_and_get_payload, (
"{}.mutate_and_get_payload method is required"
" in a ClientIDMutation."
).format(name)
input_attrs = props(input_class) if input_class else {}
input_attrs['client_mutation_id'] = String(name='clientMutationId')
cls.Input = type('{}Input'.format(base_name), (InputObjectType,), input_attrs)
cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input))
return cls
class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, ObjectType)):
@classmethod
def mutate(cls, root, args, context, info):
input = args.get('input')
def on_resolve(payload):
try:
payload.clientMutationId = input['clientMutationId']
except:
raise Exception((
'Cannot set clientMutationId in the payload object {}'
).format(repr(payload)))
return payload
return Promise.resolve(
cls.mutate_and_get_payload(input, context, info)
).then(on_resolve)

View File

@ -1,17 +1,55 @@
import six
from collections import OrderedDict
from functools import partial
from graphql_relay import from_global_id, to_global_id
from ..types import Interface, ID, Field
from ..types import ObjectType, Interface, ID, Field
from ..types.interface import InterfaceMeta
class Node(Interface):
def get_default_connection(cls):
from .connection import Connection
assert issubclass(cls, ObjectType), (
'Can only get connection type on implemented Nodes.'
)
class Meta:
node = cls
return type('{}Connection'.format(cls.__name__), (Connection,), {'Meta': Meta})
class GlobalID(Field):
def __init__(self, node, *args, **kwargs):
super(GlobalID, self).__init__(ID, *args, **kwargs)
self.node = node
@staticmethod
def id_resolver(parent_resolver, node, root, args, context, info):
id = parent_resolver(root, args, context, info)
# type_name = root._meta.name # info.parent_type.name
return node.to_global_id(info.parent_type.name, id)
def get_resolver(self, parent_resolver):
return partial(self.id_resolver, parent_resolver, self.node)
class NodeMeta(InterfaceMeta):
def __new__(cls, name, bases, attrs):
cls = InterfaceMeta.__new__(cls, name, bases, attrs)
cls._meta.fields['id'] = GlobalID(cls, required=True, description='The ID of the object.')
# new_fields = OrderedDict([
# ('id', GlobalID(cls, required=True, description='The ID of the object.'))
# ])
# new_fields.update(cls._meta.fields)
# cls._meta.fields = new_fields
return cls
class Node(six.with_metaclass(NodeMeta, Interface)):
'''An object with an ID'''
id = ID(required=True, description='The ID of the object.')
@classmethod
def resolve_id(cls, root, args, context, info):
type_name = root._meta.name # info.parent_type.name
return cls.to_global_id(type_name, getattr(root, 'id', None))
@classmethod
def Field(cls):
def resolve_node(root, args, context, info):
@ -45,13 +83,13 @@ class Node(Interface):
@classmethod
def implements(cls, objecttype):
require_get_node = Node in objecttype._meta.interfaces
# get_connection = getattr(objecttype, 'get_connection', None)
# if not get_connection:
# get_connection = partial(get_default_connection, objecttype)
require_get_node = cls.get_node_from_global_id == Node.get_node_from_global_id
get_connection = getattr(objecttype, 'get_connection', None)
if not get_connection:
get_connection = partial(get_default_connection, objecttype)
# objecttype.Connection = get_connection()
objecttype.Connection = get_connection()
if require_get_node:
assert hasattr(
objecttype, 'get_node'), '{}.get_node method is required by the Node interface.'.format(
objecttype.__name__)
assert hasattr(objecttype, 'get_node'), (
'{}.get_node method is required by the Node interface.'
).format(objecttype.__name__)

View File

@ -0,0 +1,66 @@
import pytest
from ...types import ObjectType, Schema, Field, InputField, InputObjectType, Argument
from ...types.scalars import String
from ..mutation import ClientIDMutation
class SaySomething(ClientIDMutation):
class Input:
what = String()
phrase = String()
@staticmethod
def mutate_and_get_payload(args, context, info):
what = args.get('what')
return SaySomething(phrase=str(what))
class RootQuery(ObjectType):
something = String()
class Mutation(ObjectType):
say = SaySomething.Field()
schema = Schema(query=RootQuery, mutation=Mutation)
def test_no_mutate_and_get_payload():
with pytest.raises(AssertionError) as excinfo:
class MyMutation(ClientIDMutation):
pass
assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation." == str(
excinfo.value)
def test_mutation():
fields = SaySomething._meta.fields
assert fields.keys() == ['phrase']
assert isinstance(fields['phrase'], Field)
field = SaySomething.Field()
assert field.type == SaySomething
assert field.args.keys() == ['input']
assert isinstance(field.args['input'], Argument)
assert field.args['input'].type == SaySomething.Input
def test_mutation_input():
Input = SaySomething.Input
assert issubclass(Input, InputObjectType)
fields = Input._meta.fields
assert fields.keys() == ['what', 'client_mutation_id']
assert isinstance(fields['what'], InputField)
assert fields['what'].type == String
assert isinstance(fields['client_mutation_id'], InputField)
assert fields['client_mutation_id'].type == String
# def test_node_query():
# executed = schema.execute(
# 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
# )
# assert not executed.errors
# assert executed.data == {'say': {'phrase': 'hello'}}

View File

@ -3,8 +3,7 @@ import pytest
from graphql_relay import to_global_id
from ...types import ObjectType, Schema, String
# from ...types.scalars import String
# from ..connection import Connection
from ..connection import Connection
from ..node import Node
@ -47,13 +46,13 @@ def test_node_good():
assert 'id' in MyNode._meta.fields
# def test_node_get_connection():
# connection = MyNode.Connection
# assert issubclass(connection, Connection)
def test_node_get_connection():
connection = MyNode.Connection
assert issubclass(connection, Connection)
# def test_node_get_connection_dont_duplicate():
# assert MyNode.Connection == MyNode.Connection
def test_node_get_connection_dont_duplicate():
assert MyNode.Connection == MyNode.Connection
def test_node_query():

View File

@ -5,6 +5,7 @@ from .options import Options
from .abstracttype import AbstractTypeMeta
from .utils import get_fields_in_type, yank_fields_from_attrs, merge_fields_in_attrs
from .unmountedtype import UnmountedType
class InputObjectTypeMeta(AbstractTypeMeta):
@ -19,11 +20,13 @@ class InputObjectTypeMeta(AbstractTypeMeta):
attrs.pop('Meta', None),
name=name,
description=attrs.get('__doc__'),
fields=None,
)
attrs = merge_fields_in_attrs(bases, attrs)
options.fields = get_fields_in_type(InputObjectType, attrs)
yank_fields_from_attrs(attrs, options.fields)
if not options.fields:
options.fields = get_fields_in_type(InputObjectType, attrs)
yank_fields_from_attrs(attrs, options.fields)
return type.__new__(cls, name, bases, dict(attrs, _meta=options))
@ -31,6 +34,9 @@ class InputObjectTypeMeta(AbstractTypeMeta):
return cls._meta.name
class InputObjectType(six.with_metaclass(InputObjectTypeMeta)):
def __init__(self, *args, **kwargs):
raise Exception("An InputObjectType cannot be intitialized")
class InputObjectType(six.with_metaclass(InputObjectTypeMeta, UnmountedType)):
@classmethod
def get_type(cls):
return cls
# def __init__(self, *args, **kwargs):
# raise Exception("An InputObjectType cannot be intitialized")

View File

@ -19,11 +19,13 @@ class InterfaceMeta(AbstractTypeMeta):
attrs.pop('Meta', None),
name=name,
description=attrs.get('__doc__'),
fields=None,
)
attrs = merge_fields_in_attrs(bases, attrs)
options.fields = get_fields_in_type(Interface, attrs)
yank_fields_from_attrs(attrs, options.fields)
if not options.fields:
options.fields = get_fields_in_type(Interface, attrs)
yank_fields_from_attrs(attrs, options.fields)
return type.__new__(cls, name, bases, dict(attrs, _meta=options))

View File

@ -22,20 +22,21 @@ class ObjectTypeMeta(AbstractTypeMeta):
name=name,
description=attrs.get('__doc__'),
interfaces=(),
fields=OrderedDict(),
fields=None,
)
attrs = merge_fields_in_attrs(bases, attrs)
options.local_fields = get_fields_in_type(ObjectType, attrs)
yank_fields_from_attrs(attrs, options.local_fields)
options.interface_fields = OrderedDict()
for interface in options.interfaces:
assert issubclass(interface, Interface), (
'All interfaces of {} must be a subclass of Interface. Received "{}".'
).format(name, interface)
options.interface_fields.update(interface._meta.fields)
options.fields.update(options.interface_fields)
options.fields.update(options.local_fields)
if not options.fields:
options.local_fields = get_fields_in_type(ObjectType, attrs)
yank_fields_from_attrs(attrs, options.local_fields)
options.interface_fields = OrderedDict()
for interface in options.interfaces:
assert issubclass(interface, Interface), (
'All interfaces of {} must be a subclass of Interface. Received "{}".'
).format(name, interface)
options.interface_fields.update(interface._meta.fields)
options.fields = OrderedDict(options.interface_fields)
options.fields.update(options.local_fields)
cls = type.__new__(cls, name, bases, dict(attrs, _meta=options))
for interface in options.interfaces: