From d67b7bc6a141174ee63bf35f2cd054acfd70c2e9 Mon Sep 17 00:00:00 2001 From: Syrus Akbary Date: Thu, 9 Jun 2016 21:18:42 -0700 Subject: [PATCH] Improved relay compatibility --- examples/starwars_relay/schema.py | 46 +++--- .../starwars_relay/tests/test_mutation.py | 83 +++++----- graphene/relay/__init__.py | 2 +- graphene/relay/mutation.py | 62 ++++++++ graphene/relay/node.py | 15 +- graphene/relay/tests/test_mutation.py | 84 ++++++++++ graphene/types/field.py | 5 +- graphene/types/mutation.py | 12 +- graphene/types/objecttype.py | 2 + graphene/types/tests/test_objecttype.py | 144 +++++++++--------- graphene/utils/extract_fields.py | 2 +- 11 files changed, 312 insertions(+), 145 deletions(-) create mode 100644 graphene/relay/mutation.py create mode 100644 graphene/relay/tests/test_mutation.py diff --git a/examples/starwars_relay/schema.py b/examples/starwars_relay/schema.py index f2c33834..bd292d09 100644 --- a/examples/starwars_relay/schema.py +++ b/examples/starwars_relay/schema.py @@ -1,33 +1,37 @@ import graphene -from graphene import relay, resolve_only_args +from graphene import implements, relay, resolve_only_args from .data import create_ship, get_empire, get_faction, get_rebels, get_ship -schema = graphene.Schema(name='Starwars Relay Schema') - -class Ship(relay.Node): +# @implements(relay.Node) +class Ship(graphene.ObjectType): + class Meta: + interfaces = [relay.Node] '''A ship in the Star Wars saga''' name = graphene.String(description='The name of the ship.') @classmethod - def get_node(cls, id, info): + def get_node(cls, id, context, info): return get_ship(id) -class Faction(relay.Node): +# @implements(relay.Node) +class Faction(graphene.ObjectType): + class Meta: + interfaces = [relay.Node] '''A faction in the Star Wars saga''' name = graphene.String(description='The name of the faction.') - ships = relay.ConnectionField( - Ship, description='The ships used by the faction.') - - @resolve_only_args - def resolve_ships(self, **args): - # Transform the instance ship_ids into real instances - return [get_ship(ship_id) for ship_id in self.ships] + # ships = relay.ConnectionField( + # Ship, description='The ships used by the faction.') + ships = graphene.List(graphene.String) + # @resolve_only_args + # def resolve_ships(self, **args): + # # Transform the instance ship_ids into real instances + # return [get_ship(ship_id) for ship_id in self.ships] @classmethod - def get_node(cls, id, info): + def get_node(cls, id, context, info): return get_faction(id) @@ -41,9 +45,9 @@ class IntroduceShip(relay.ClientIDMutation): faction = graphene.Field(Faction) @classmethod - def mutate_and_get_payload(cls, input, info): - ship_name = input.get('ship_name') - faction_id = input.get('faction_id') + def mutate_and_get_payload(cls, input, context, info): + ship_name = input.get('shipName') + faction_id = input.get('factionId') ship = create_ship(ship_name, faction_id) faction = get_faction(faction_id) return IntroduceShip(ship=ship, faction=faction) @@ -52,7 +56,7 @@ class IntroduceShip(relay.ClientIDMutation): class Query(graphene.ObjectType): rebels = graphene.Field(Faction) empire = graphene.Field(Faction) - node = relay.NodeField() + node = relay.Node.Field @resolve_only_args def resolve_rebels(self): @@ -64,8 +68,8 @@ class Query(graphene.ObjectType): class Mutation(graphene.ObjectType): - introduce_ship = graphene.Field(IntroduceShip) + introduce_ship = IntroduceShip.Field -schema.query = Query -schema.mutation = Mutation + +schema = graphene.Schema(query=Query, mutation=Mutation) diff --git a/examples/starwars_relay/tests/test_mutation.py b/examples/starwars_relay/tests/test_mutation.py index 762d4b8b..c5d63139 100644 --- a/examples/starwars_relay/tests/test_mutation.py +++ b/examples/starwars_relay/tests/test_mutation.py @@ -14,14 +14,14 @@ def test_mutations(): } faction { name - ships { - edges { - node { - id - name - } - } - } + # ships { + # edges { + # node { + # id + # name + # } + # } + # } } } } @@ -34,42 +34,43 @@ def test_mutations(): }, 'faction': { 'name': 'Alliance to Restore the Republic', - 'ships': { - 'edges': [{ - 'node': { - 'id': 'U2hpcDox', - 'name': 'X-Wing' - } - }, { - 'node': { - 'id': 'U2hpcDoy', - 'name': 'Y-Wing' - } - }, { - 'node': { - 'id': 'U2hpcDoz', - 'name': 'A-Wing' - } - }, { - 'node': { - 'id': 'U2hpcDo0', - 'name': 'Millenium Falcon' - } - }, { - 'node': { - 'id': 'U2hpcDo1', - 'name': 'Home One' - } - }, { - 'node': { - 'id': 'U2hpcDo5', - 'name': 'Peter' - } - }] - }, + # 'ships': { + # 'edges': [{ + # 'node': { + # 'id': 'U2hpcDox', + # 'name': 'X-Wing' + # } + # }, { + # 'node': { + # 'id': 'U2hpcDoy', + # 'name': 'Y-Wing' + # } + # }, { + # 'node': { + # 'id': 'U2hpcDoz', + # 'name': 'A-Wing' + # } + # }, { + # 'node': { + # 'id': 'U2hpcDo0', + # 'name': 'Millenium Falcon' + # } + # }, { + # 'node': { + # 'id': 'U2hpcDo1', + # 'name': 'Home One' + # } + # }, { + # 'node': { + # 'id': 'U2hpcDo5', + # 'name': 'Peter' + # } + # }] + # }, } } } result = schema.execute(query) + # raise result.errors[0].original_error, None, result.errors[0].stack assert not result.errors assert result.data == expected diff --git a/graphene/relay/__init__.py b/graphene/relay/__init__.py index 136d4eb9..afc6bbb0 100644 --- a/graphene/relay/__init__.py +++ b/graphene/relay/__init__.py @@ -1,2 +1,2 @@ from .node import Node -# from .mutation import ClientIDMutation +from .mutation import ClientIDMutation diff --git a/graphene/relay/mutation.py b/graphene/relay/mutation.py new file mode 100644 index 00000000..3f33a5e8 --- /dev/null +++ b/graphene/relay/mutation.py @@ -0,0 +1,62 @@ +from functools import partial +import copy +import six +from graphql_relay import mutation_with_client_mutation_id + +from ..types.mutation import Mutation, MutationMeta +from ..types.inputobjecttype import GrapheneInputObjectType, InputObjectType +from ..types.objecttype import GrapheneObjectType +from ..types.field import Field, InputField + +from ..utils.props import props + + +class ClientIDMutationMeta(MutationMeta): + _construct_field = False + + def get_options(cls, meta): + options = cls.options_class( + meta, + name=None, + abstract=False + ) + options.graphql_type = None + options.interfaces = [] + return options + + def construct(cls, bases, attrs): + if not cls._meta.abstract: + Input = attrs.pop('Input', None) + field_attrs = props(Input) if Input else {} + + cls.mutate_and_get_payload = attrs.pop('mutate_and_get_payload', None) + + input_local_fields = {f.name: f for f in InputObjectType._extract_local_fields(field_attrs)} + local_fields = cls._extract_local_fields(attrs) + assert cls.mutate_and_get_payload, "{}.mutate_and_get_payload method is required in a ClientIDMutation ObjectType.".format(cls.__name__) + field = mutation_with_client_mutation_id( + name=cls._meta.name or cls.__name__, + input_fields=input_local_fields, + output_fields=cls._fields(bases, attrs, local_fields), + mutate_and_get_payload=cls.mutate_and_get_payload, + + input_type_class=partial(GrapheneInputObjectType, graphene_type=cls), + input_field_class=InputField, + output_type_class=partial(GrapheneObjectType, graphene_type=cls), + field_class=Field, + ) + cls._meta.graphql_type = field.type + cls._Field = field + constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs) + return constructed + + @property + def Field(cls): + field = copy.copy(cls._Field) + field.reset_counter() + return field + + +class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, Mutation)): + class Meta: + abstract = True diff --git a/graphene/relay/node.py b/graphene/relay/node.py index ea3d9e5f..4045af18 100644 --- a/graphene/relay/node.py +++ b/graphene/relay/node.py @@ -1,7 +1,7 @@ import copy from functools import partial import six -from graphql_relay import node_definitions, from_global_id +from graphql_relay import node_definitions, from_global_id, to_global_id from ..types.field import Field from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMeta @@ -14,10 +14,12 @@ class NodeMeta(InterfaceTypeMeta): def construct(cls, bases, attrs): cls.get_node = attrs.pop('get_node') + cls.id_resolver = attrs.pop('id_resolver', None) node_interface, node_field = node_definitions( cls.get_node, + id_resolver=cls.id_resolver, interface_class=partial(GrapheneInterfaceType, graphene_type=cls), - field_class=Field + field_class=Field, ) cls._meta.graphql_type = node_interface cls._Field = node_field @@ -42,6 +44,14 @@ class Node(six.with_metaclass(NodeMeta, Interface)): def from_global_id(cls, global_id): return from_global_id(global_id) + @classmethod + def to_global_id(cls, type, id): + return to_global_id(type, id) + + @classmethod + def id_resolver(cls, root, args, context, info): + return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None)) + @classmethod def get_node(cls, global_id, context, info): try: @@ -61,4 +71,5 @@ class Node(six.with_metaclass(NodeMeta, Interface)): ''' if cls.require_get_node(): assert hasattr(object_type, 'get_node'), '{}.get_node method is required by the Node interface.'.format(object_type.__name__) + return super(Node, cls).implements(object_type) diff --git a/graphene/relay/tests/test_mutation.py b/graphene/relay/tests/test_mutation.py new file mode 100644 index 00000000..caced509 --- /dev/null +++ b/graphene/relay/tests/test_mutation.py @@ -0,0 +1,84 @@ +import pytest + +from graphql_relay import to_global_id + +from ..mutation import ClientIDMutation +from ...types import ObjectType, Schema, implements +from ...types.scalars import String + + +class SaySomething(ClientIDMutation): + class Input: + what = String() + phrase = String() + + @staticmethod + def mutate_and_get_payload(args, context, info): + what = args.get('what') + return SaySomething(phrase=str(what)) + + +class RootQuery(ObjectType): + something = String() + + +class Mutation(ObjectType): + say = SaySomething.Field + +schema = Schema(query=RootQuery, mutation=Mutation) + + +def test_no_mutate_and_get_payload(): + with pytest.raises(AssertionError) as excinfo: + class MyMutation(ClientIDMutation): + pass + + assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation ObjectType." == str(excinfo.value) + + +def test_node_good(): + graphql_type = SaySomething._meta.graphql_type + fields = graphql_type.get_fields() + assert 'phrase' in fields + assert SaySomething.Field.type == SaySomething._meta.graphql_type + graphql_field = SaySomething.Field + assert 'input' in graphql_field.args + input = graphql_field.args['input'] + assert 'clientMutationId' in input.type.of_type.get_fields() + + +def test_node_query(): + executed = schema.execute( + 'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }' + ) + assert not executed.errors + assert executed.data == {'say': {'phrase': 'hello'}} + + +# def test_node_query_incorrect_id(): +# executed = schema.execute( +# '{ node(id:"%s") { ... on MyNode { name } } }' % "something:2" +# ) +# assert not executed.errors +# assert executed.data == {'node': None} + +# def test_str_schema(): +# assert str(schema) == """ +# schema { +# query: RootQuery +# } + +# type MyNode implements Node { +# id: ID! +# name: String +# } + +# interface Node { +# id: ID! +# } + +# type RootQuery { +# first: String +# node(id: ID!): Node +# } +# """.lstrip() \ No newline at end of file diff --git a/graphene/types/field.py b/graphene/types/field.py index adb63a71..032414e4 100644 --- a/graphene/types/field.py +++ b/graphene/types/field.py @@ -11,7 +11,7 @@ from .argument import to_arguments class AbstractField(object): @property def name(self): - return self._name or to_camel_case(self.attname) + return self._name or self.attname and to_camel_case(self.attname) @name.setter def name(self, name): @@ -81,7 +81,8 @@ class Field(AbstractField, GraphQLField, OrderedType): # We try to get the resolver from the interfaces if not resolver and issubclass(self.parent, ObjectType): graphql_type = self.parent._meta.graphql_type - for interface in graphql_type._provided_interfaces: + interfaces = graphql_type._provided_interfaces or [] + for interface in interfaces: if not isinstance(interface, GrapheneInterfaceType): continue fields = interface.get_fields() diff --git a/graphene/types/mutation.py b/graphene/types/mutation.py index e1e54d04..f87fd570 100644 --- a/graphene/types/mutation.py +++ b/graphene/types/mutation.py @@ -9,17 +9,19 @@ from ..utils.props import props class MutationMeta(ObjectTypeMeta): - def construct_field(cls, field_attrs): + _construct_field = True + + def construct_field(cls, field_args): resolver = getattr(cls, 'mutate', None) assert resolver, 'All mutations must define a mutate method in it' - return partial(Field, cls, args=field_attrs, resolver=resolver) + return partial(Field, cls, args=field_args, resolver=resolver) def construct(cls, bases, attrs): super(MutationMeta, cls).construct(bases, attrs) - if not cls._meta.abstract: + if not cls._meta.abstract and cls._construct_field: Input = attrs.pop('Input', None) - field_attrs = props(Input) if Input else {} - cls.Field = cls.construct_field(field_attrs) + field_args = props(Input) if Input else {} + cls.Field = cls.construct_field(field_args) return cls diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 01bc6e5c..0d96e1f0 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -14,6 +14,8 @@ class GrapheneObjectType(GrapheneGraphQLType, GraphQLObjectType): self.check_interfaces() def check_interfaces(self): + if not self._provided_interfaces: + return for interface in self._provided_interfaces: if isinstance(interface, GrapheneInterfaceType): interface.graphene_type.implements(self.graphene_type) diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index e7ae9d64..d9683581 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -97,108 +97,108 @@ def test_parent_container_get_fields(): assert fields.keys() == ['field1', 'field2'] -# def test_objecttype_as_container_only_args(): -# container = Container("1", "2") -# assert container.field1 == "1" -# assert container.field2 == "2" +def test_objecttype_as_container_only_args(): + container = Container("1", "2") + assert container.field1 == "1" + assert container.field2 == "2" -# def test_objecttype_as_container_args_kwargs(): -# container = Container("1", field2="2") -# assert container.field1 == "1" -# assert container.field2 == "2" +def test_objecttype_as_container_args_kwargs(): + container = Container("1", field2="2") + assert container.field1 == "1" + assert container.field2 == "2" -# def test_objecttype_as_container_few_kwargs(): -# container = Container(field2="2") -# assert container.field2 == "2" +def test_objecttype_as_container_few_kwargs(): + container = Container(field2="2") + assert container.field2 == "2" -# def test_objecttype_as_container_all_kwargs(): -# container = Container(field1="1", field2="2") -# assert container.field1 == "1" -# assert container.field2 == "2" +def test_objecttype_as_container_all_kwargs(): + container = Container(field1="1", field2="2") + assert container.field1 == "1" + assert container.field2 == "2" -# def test_objecttype_as_container_extra_args(): -# with pytest.raises(IndexError) as excinfo: -# Container("1", "2", "3") +def test_objecttype_as_container_extra_args(): + with pytest.raises(IndexError) as excinfo: + Container("1", "2", "3") -# assert "Number of args exceeds number of fields" == str(excinfo.value) + assert "Number of args exceeds number of fields" == str(excinfo.value) -# def test_objecttype_as_container_invalid_kwargs(): -# with pytest.raises(TypeError) as excinfo: -# Container(unexisting_field="3") +def test_objecttype_as_container_invalid_kwargs(): + with pytest.raises(TypeError) as excinfo: + Container(unexisting_field="3") -# assert "'unexisting_field' is an invalid keyword argument for this function" == str(excinfo.value) + assert "'unexisting_field' is an invalid keyword argument for this function" == str(excinfo.value) -# def test_objecttype_reuse_graphql_type(): -# MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ -# 'field': GraphQLField(GraphQLString) -# }) +def test_objecttype_reuse_graphql_type(): + MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ + 'field': GraphQLField(GraphQLString) + }) -# class GrapheneObjectType(ObjectType): -# class Meta: -# graphql_type = MyGraphQLType + class GrapheneObjectType(ObjectType): + class Meta: + graphql_type = MyGraphQLType -# graphql_type = GrapheneObjectType._meta.graphql_type -# assert graphql_type == MyGraphQLType -# instance = GrapheneObjectType(field="A") -# assert instance.field == "A" + graphql_type = GrapheneObjectType._meta.graphql_type + assert graphql_type == MyGraphQLType + instance = GrapheneObjectType(field="A") + assert instance.field == "A" -# def test_objecttype_add_fields_in_reused_graphql_type(): -# MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ -# 'field': GraphQLField(GraphQLString) -# }) +def test_objecttype_add_fields_in_reused_graphql_type(): + MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ + 'field': GraphQLField(GraphQLString) + }) -# with pytest.raises(AssertionError) as excinfo: -# class GrapheneObjectType(ObjectType): -# field = Field(GraphQLString) + with pytest.raises(AssertionError) as excinfo: + class GrapheneObjectType(ObjectType): + field = Field(GraphQLString) -# class Meta: -# graphql_type = MyGraphQLType + class Meta: + graphql_type = MyGraphQLType -# assert """Field "MyGraphQLType.field" can only be mounted in ObjectType or Interface, received GrapheneObjectType.""" == str(excinfo.value) + assert """Can't mount Fields in an ObjectType with a defined graphql_type""" == str(excinfo.value) -# def test_objecttype_graphql_interface(): -# MyInterface = GraphQLInterfaceType('MyInterface', fields={ -# 'field': GraphQLField(GraphQLString) -# }) +def test_objecttype_graphql_interface(): + MyInterface = GraphQLInterfaceType('MyInterface', fields={ + 'field': GraphQLField(GraphQLString) + }) -# class GrapheneObjectType(ObjectType): -# class Meta: -# interfaces = [MyInterface] + class GrapheneObjectType(ObjectType): + class Meta: + interfaces = [MyInterface] -# graphql_type = GrapheneObjectType._meta.graphql_type -# assert graphql_type.get_interfaces() == (MyInterface, ) -# # assert graphql_type.is_type_of(MyInterface, None, None) -# fields = graphql_type.get_fields() -# assert 'field' in fields + graphql_type = GrapheneObjectType._meta.graphql_type + assert graphql_type.get_interfaces() == (MyInterface, ) + # assert graphql_type.is_type_of(MyInterface, None, None) + fields = graphql_type.get_fields() + assert 'field' in fields -# def test_objecttype_graphene_interface(): -# class GrapheneInterface(Interface): -# name = Field(GraphQLString) -# extended = Field(GraphQLString) +def test_objecttype_graphene_interface(): + class GrapheneInterface(Interface): + name = Field(GraphQLString) + extended = Field(GraphQLString) -# class GrapheneObjectType(ObjectType): -# class Meta: -# interfaces = [GrapheneInterface] + class GrapheneObjectType(ObjectType): + class Meta: + interfaces = [GrapheneInterface] -# field = Field(GraphQLString) + field = Field(GraphQLString) -# graphql_type = GrapheneObjectType._meta.graphql_type -# assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, ) -# assert graphql_type.is_type_of(GrapheneObjectType(), None, None) -# fields = graphql_type.get_fields() -# assert 'field' in fields -# assert 'extended' in fields -# assert 'name' in fields -# assert fields['field'] > fields['extended'] > fields['name'] + graphql_type = GrapheneObjectType._meta.graphql_type + assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, ) + assert graphql_type.is_type_of(GrapheneObjectType(), None, None) + fields = graphql_type.get_fields() + assert 'field' in fields + assert 'extended' in fields + assert 'name' in fields + assert fields['field'] > fields['extended'] > fields['name'] # def test_objecttype_graphene_interface_extended(): diff --git a/graphene/utils/extract_fields.py b/graphene/utils/extract_fields.py index 5b6727c7..49231e0e 100644 --- a/graphene/utils/extract_fields.py +++ b/graphene/utils/extract_fields.py @@ -31,7 +31,7 @@ def get_base_fields(cls, bases): if attname in fields: continue field = copy.copy(field) - if isinstance(field, Field): + if isinstance(field, (Field, InputField)): field.parent = cls fields.add(attname) _fields.append(field)