diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py index 8db462d6..2fa6e187 100644 --- a/graphene/relay/mutation.py +++ b/graphene/relay/mutation.py @@ -21,8 +21,8 @@ class ClientIDMutationMeta(ObjectTypeMeta): input_class = attrs.pop('Input', None) base_name = re.sub('Payload$', '', name) - if 'client_mutation_id' not in attrs: - attrs['client_mutation_id'] = String(name='clientMutationId') + default_client_mutation_id = String(name='clientMutationId') + attrs['client_mutation_id'] = attrs.get('client_mutation_id', default_client_mutation_id) cls = ObjectTypeMeta.__new__(cls, '{}Payload'.format(base_name), bases, attrs) mutate_and_get_payload = getattr(cls, 'mutate_and_get_payload', None) if cls.mutate and cls.mutate.__func__ == ClientIDMutation.mutate.__func__: @@ -38,7 +38,7 @@ class ClientIDMutationMeta(ObjectTypeMeta): input_attrs = props(input_class) else: bases += (input_class, ) - input_attrs['client_mutation_id'] = String(name='clientMutationId') + input_attrs['client_mutation_id'] = default_client_mutation_id cls.Input = type('{}Input'.format(base_name), bases + (InputObjectType,), input_attrs) cls.Field = partial(Field, cls, resolver=cls.mutate, input=Argument(cls.Input, required=True)) return cls diff --git a/graphene/relay/node.py b/graphene/relay/node.py index c1bbe6d5..4c0258cb 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -35,24 +35,26 @@ def get_default_connection(cls): class GlobalID(Field): - def __init__(self, node, *args, **kwargs): - super(GlobalID, self).__init__(ID, *args, **kwargs) - self.node = node + def __init__(self, node=None, parent_type=None, required=True, *args, **kwargs): + super(GlobalID, self).__init__(ID, required=required, *args, **kwargs) + self._node = node or Node + self._parent_type_name = parent_type._meta.name if parent_type else None @staticmethod - def id_resolver(parent_resolver, node, root, args, context, info): - id = parent_resolver(root, args, context, info) - return node.to_global_id(info.parent_type.name, id) # root._meta.name + def id_resolver(parent_resolver, node, root, args, context, info, parent_type_name=None): + type_id = parent_resolver(root, args, context, info) + parent_type_name = parent_type_name or info.parent_type.name # root._meta.name + return node.to_global_id(parent_type_name, type_id) def get_resolver(self, parent_resolver): - return partial(self.id_resolver, parent_resolver, self.node) + return partial(self.id_resolver, parent_resolver, self._node, parent_type_name=self._parent_type_name) class NodeMeta(InterfaceMeta): def __new__(cls, name, bases, attrs): cls = InterfaceMeta.__new__(cls, name, bases, attrs) - cls._meta.fields['id'] = GlobalID(cls, required=True, description='The ID of the object.') + cls._meta.fields['id'] = GlobalID(cls, description='The ID of the object.') return cls diff --git a/graphene/relay/tests/test_mutation.py b/graphene/relay/tests/test_mutation.py index 34fbb936..4c14c684 100644 --- a/graphene/relay/tests/test_mutation.py +++ b/graphene/relay/tests/test_mutation.py @@ -1,10 +1,12 @@ import pytest +from graphql_relay import to_global_id + from ...types import (AbstractType, Argument, Field, InputField, InputObjectType, NonNull, ObjectType, Schema) from ...types.scalars import String from ..mutation import ClientIDMutation -from ..node import Node +from ..node import GlobalID, Node class SharedFields(AbstractType): @@ -23,12 +25,14 @@ class SaySomething(ClientIDMutation): class Input: what = String() + phrase = String() + my_node_id = GlobalID(parent_type=MyNode) @staticmethod def mutate_and_get_payload(args, context, info): what = args.get('what') - return SaySomething(phrase=str(what)) + return SaySomething(phrase=str(what), my_node_id=1) class OtherMutation(ClientIDMutation): @@ -71,8 +75,9 @@ def test_no_mutate_and_get_payload(): def test_mutation(): fields = SaySomething._meta.fields - assert list(fields.keys()) == ['phrase', 'client_mutation_id'] + assert list(fields.keys()) == ['phrase', 'my_node_id', 'client_mutation_id'] assert isinstance(fields['phrase'], Field) + assert isinstance(fields['my_node_id'], GlobalID) field = SaySomething.Field() assert field.type == SaySomething assert list(field.args.keys()) == ['input'] @@ -120,12 +125,13 @@ def test_subclassed_mutation_input(): assert fields['client_mutation_id'].type == String -# def test_node_query(): -# executed = schema.execute( -# 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }' -# ) -# assert not executed.errors -# assert executed.data == {'say': {'phrase': 'hello'}} +def test_node_query(): + executed = schema.execute( + 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase, clientMutationId, myNodeId} }' + ) + assert not executed.errors + assert dict(executed.data) == {'say': {'myNodeId': to_global_id('MyNode', '1'), 'clientMutationId': '1', 'phrase': 'hello'}} + def test_edge_query(): executed = schema.execute( diff --git a/graphene/relay/tests/test_node.py b/graphene/relay/tests/test_node.py index 12c7c2c2..70448ebb 100644 --- a/graphene/relay/tests/test_node.py +++ b/graphene/relay/tests/test_node.py @@ -4,7 +4,7 @@ from graphql_relay import to_global_id from ...types import AbstractType, ObjectType, Schema, String from ..connection import Connection -from ..node import Node +from ..node import Node, GlobalID class SharedNodeFields(AbstractType): @@ -27,6 +27,18 @@ class MyNode(ObjectType): return MyNode(name=str(id)) +class MyNodeImplementedId(ObjectType): + + class Meta: + interfaces = (Node, ) + id = GlobalID() + name = String() + + @staticmethod + def get_node(id, *_): + return MyNodeImplementedId(name=str(id) + '!') + + class MyOtherNode(SharedNodeFields, ObjectType): extra_field = String() @@ -45,7 +57,7 @@ class RootQuery(ObjectType): first = String() node = Node.Field() -schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode]) +schema = Schema(query=RootQuery, types=[MyNode, MyOtherNode, MyNodeImplementedId]) def test_node_good(): @@ -78,6 +90,14 @@ def test_subclassed_node_query(): [('shared', '1'), ('extraField', 'extra field info.'), ('somethingElse', '----')])}) +def test_node_query_implemented_id(): + executed = schema.execute( + '{ node(id:"%s") { ... on MyNodeImplementedId { name } } }' % to_global_id("MyNodeImplementedId", 1) + ) + assert not executed.errors + assert executed.data == {'node': {'name': '1!'}} + + def test_node_query_incorrect_id(): executed = schema.execute( '{ node(id:"%s") { ... on MyNode { name } } }' % "something:2" @@ -97,6 +117,11 @@ type MyNode implements Node { name: String } +type MyNodeImplementedId implements Node { + id: ID! + name: String +} + type MyOtherNode implements Node { id: ID! shared: String