This commit is contained in:
Markus Padourek 2016-11-10 09:17:20 +00:00 committed by GitHub
commit 47d96fae12
4 changed files with 55 additions and 22 deletions

View File

@ -21,8 +21,8 @@ class ClientIDMutationMeta(ObjectTypeMeta):
input_class = attrs.pop('Input', None) input_class = attrs.pop('Input', None)
base_name = re.sub('Payload$', '', name) base_name = re.sub('Payload$', '', name)
if 'client_mutation_id' not in attrs: default_client_mutation_id = String(name='clientMutationId')
attrs['client_mutation_id'] = String(name='clientMutationId') attrs['client_mutation_id'] = attrs.get('client_mutation_id', default_client_mutation_id)
cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs) cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs)
mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None) mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None)
if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__: if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__:
@ -38,7 +38,7 @@ class ClientIDMutationMeta(ObjectTypeMeta):
input_attrs = props(input_class) input_attrs = props(input_class)
else: else:
bases += (input_class, ) bases += (input_class, )
input_attrs['client_mutation_id'] = String(name='clientMutationId') input_attrs['client_mutation_id'] = default_client_mutation_id
cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs) cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs)
cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input, required=True)) cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input, required=True))
return cls return cls

View File

@ -35,24 +35,26 @@ def get_default_connection(cls):
class GlobalID(Field): class GlobalID(Field):
def __init__(self, node, *args, **kwargs): def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs):
super(GlobalID, self).__init__(ID, *args, **kwargs) super(GlobalID, self).__init__(ID, required=required, *args, **kwargs)
self.node = node self._node = node or Node
self._parent_type_name = parent_type._meta.name if parent_type else None
@staticmethod @staticmethod
def id_resolver(parent_resolver, node, root, args, context, info): def id_resolver(parent_resolver, node, root, args, context, info, parent_type_name=None):
id = parent_resolver(root, args, context, info) type_id = parent_resolver(root, args, context, info)
return node.to_global_id(info.parent_type.name, id) # root._meta.name parent_type_name = parent_type_name or info.parent_type.name # root._meta.name
return node.to_global_id(parent_type_name, type_id)
def get_resolver(self, parent_resolver): def get_resolver(self, parent_resolver):
return partial(self.id_resolver, parent_resolver, self.node) return partial(self.id_resolver, parent_resolver, self._node, parent_type_name=self._parent_type_name)
class NodeMeta(InterfaceMeta): class NodeMeta(InterfaceMeta):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
cls = InterfaceMeta.__new__(cls, name, bases, attrs) cls = InterfaceMeta.__new__(cls, name, bases, attrs)
cls._meta.fields['id'] = GlobalID(cls, required=True, description='The ID of the object.') cls._meta.fields['id'] = GlobalID(cls, description='The ID of the object.')
return cls return cls

View File

@ -1,10 +1,12 @@
import pytest import pytest
from graphql_relay import to_global_id
from ...types import (AbstractType, Argument, Field, InputField, from ...types import (AbstractType, Argument, Field, InputField,
InputObjectType, NonNull, ObjectType, Schema) InputObjectType, NonNull, ObjectType, Schema)
from ...types.scalars import String from ...types.scalars import String
from ..mutation import ClientIDMutation from ..mutation import ClientIDMutation
from ..node import Node from ..node import GlobalID, Node
class SharedFields(AbstractType): class SharedFields(AbstractType):
@ -23,12 +25,14 @@ class SaySomething(ClientIDMutation):
class Input: class Input:
what = String() what = String()
phrase = String() phrase = String()
my_node_id = GlobalID(parent_type=MyNode)
@staticmethod @staticmethod
def mutate_and_get_payload(args, context, info): def mutate_and_get_payload(args, context, info):
what = args.get('what') what = args.get('what')
return SaySomething(phrase=str(what)) return SaySomething(phrase=str(what), my_node_id=1)
class OtherMutation(ClientIDMutation): class OtherMutation(ClientIDMutation):
@ -71,8 +75,9 @@ def test_no_mutate_and_get_payload():
def test_mutation(): def test_mutation():
fields = SaySomething._meta.fields fields = SaySomething._meta.fields
assert list(fields.keys()) == ['phrase', 'client_mutation_id'] assert list(fields.keys()) == ['phrase', 'my_node_id', 'client_mutation_id']
assert isinstance(fields['phrase'], Field) assert isinstance(fields['phrase'], Field)
assert isinstance(fields['my_node_id'], GlobalID)
field = SaySomething.Field() field = SaySomething.Field()
assert field.type == SaySomething assert field.type == SaySomething
assert list(field.args.keys()) == ['input'] assert list(field.args.keys()) == ['input']
@ -120,12 +125,13 @@ def test_subclassed_mutation_input():
assert fields['client_mutation_id'].type == String assert fields['client_mutation_id'].type == String
# def test_node_query(): def test_node_query():
# executed = schema.execute( executed = schema.execute(
# 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }' 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase, clientMutationId, myNodeId} }'
# ) )
# assert not executed.errors assert not executed.errors
# assert executed.data == {'say': {'phrase': 'hello'}} assert dict(executed.data) == {'say': {'myNodeId': to_global_id('MyNode', '1'), 'clientMutationId': '1', 'phrase': 'hello'}}
def test_edge_query(): def test_edge_query():
executed = schema.execute( executed = schema.execute(

View File

@ -4,7 +4,7 @@ from graphql_relay import to_global_id
from ...types import AbstractType, ObjectType, Schema, String from ...types import AbstractType, ObjectType, Schema, String
from ..connection import Connection from ..connection import Connection
from ..node import Node from ..node import Node, GlobalID
class SharedNodeFields(AbstractType): class SharedNodeFields(AbstractType):
@ -27,6 +27,18 @@ class MyNode(ObjectType):
return MyNode(name=str(id)) return MyNode(name=str(id))
class MyNodeImplementedId(ObjectType):
class Meta:
interfaces = (Node, )
id = GlobalID()
name = String()
@staticmethod
def get_node(id, *_):
return MyNodeImplementedId(name=str(id) + '!')
class MyOtherNode(SharedNodeFields, ObjectType): class MyOtherNode(SharedNodeFields, ObjectType):
extra_field = String() extra_field = String()
@ -45,7 +57,7 @@ class RootQuery(ObjectType):
first = String() first = String()
node = Node.Field() node = Node.Field()
schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode]) schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode, MyNodeImplementedId])
def test_node_good(): def test_node_good():
@ -78,6 +90,14 @@ def test_subclassed_node_query():
[('shared', '1'), ('extraField', 'extra field info.'), ('somethingElse', '----')])}) [('shared', '1'), ('extraField', 'extra field info.'), ('somethingElse', '----')])})
def test_node_query_implemented_id():
executed = schema.execute(
'{ node(id:"%s") { ... on MyNodeImplementedId { name } } }' % to_global_id("MyNodeImplementedId", 1)
)
assert not executed.errors
assert executed.data == {'node': {'name': '1!'}}
def test_node_query_incorrect_id(): def test_node_query_incorrect_id():
executed = schema.execute( executed = schema.execute(
'{ node(id:"%s") { ... on MyNode { name } } }' % "something:2" '{ node(id:"%s") { ... on MyNode { name } } }' % "something:2"
@ -97,6 +117,11 @@ type MyNode implements Node {
name: String name: String
} }
type MyNodeImplementedId implements Node {
id: ID!
name: String
}
type MyOtherNode implements Node { type MyOtherNode implements Node {
id: ID! id: ID!
shared: String shared: String