Improved relay compatibility

This commit is contained in:
Syrus Akbary 2016-06-09 21:18:42 -07:00
parent b24e9a1051
commit d67b7bc6a1
11 changed files with 312 additions and 145 deletions

View File

@ -1,33 +1,37 @@
import graphene import graphene
from graphene import relay, resolve_only_args from graphene import implements, relay, resolve_only_args
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
schema = graphene.Schema(name='Starwars Relay Schema')
# @implements(relay.Node)
class Ship(relay.Node): class Ship(graphene.ObjectType):
class Meta:
interfaces = [relay.Node]
'''A ship in the Star Wars saga''' '''A ship in the Star Wars saga'''
name = graphene.String(description='The name of the ship.') name = graphene.String(description='The name of the ship.')
@classmethod @classmethod
def get_node(cls, id, info): def get_node(cls, id, context, info):
return get_ship(id) return get_ship(id)
class Faction(relay.Node): # @implements(relay.Node)
class Faction(graphene.ObjectType):
class Meta:
interfaces = [relay.Node]
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
ships = relay.ConnectionField( # ships = relay.ConnectionField(
Ship, description='The ships used by the faction.') # Ship, description='The ships used by the faction.')
ships = graphene.List(graphene.String)
@resolve_only_args # @resolve_only_args
def resolve_ships(self, **args): # def resolve_ships(self, **args):
# Transform the instance ship_ids into real instances # # Transform the instance ship_ids into real instances
return [get_ship(ship_id) for ship_id in self.ships] # return [get_ship(ship_id) for ship_id in self.ships]
@classmethod @classmethod
def get_node(cls, id, info): def get_node(cls, id, context, info):
return get_faction(id) return get_faction(id)
@ -41,9 +45,9 @@ class IntroduceShip(relay.ClientIDMutation):
faction = graphene.Field(Faction) faction = graphene.Field(Faction)
@classmethod @classmethod
def mutate_and_get_payload(cls, input, info): def mutate_and_get_payload(cls, input, context, info):
ship_name = input.get('ship_name') ship_name = input.get('shipName')
faction_id = input.get('faction_id') faction_id = input.get('factionId')
ship = create_ship(ship_name, faction_id) ship = create_ship(ship_name, faction_id)
faction = get_faction(faction_id) faction = get_faction(faction_id)
return IntroduceShip(ship=ship, faction=faction) return IntroduceShip(ship=ship, faction=faction)
@ -52,7 +56,7 @@ class IntroduceShip(relay.ClientIDMutation):
class Query(graphene.ObjectType): class Query(graphene.ObjectType):
rebels = graphene.Field(Faction) rebels = graphene.Field(Faction)
empire = graphene.Field(Faction) empire = graphene.Field(Faction)
node = relay.NodeField() node = relay.Node.Field
@resolve_only_args @resolve_only_args
def resolve_rebels(self): def resolve_rebels(self):
@ -64,8 +68,8 @@ class Query(graphene.ObjectType):
class Mutation(graphene.ObjectType): class Mutation(graphene.ObjectType):
introduce_ship = graphene.Field(IntroduceShip) introduce_ship = IntroduceShip.Field
schema.query = Query
schema.mutation = Mutation schema = graphene.Schema(query=Query, mutation=Mutation)

View File

@ -14,14 +14,14 @@ def test_mutations():
} }
faction { faction {
name name
ships { # ships {
edges { # edges {
node { # node {
id # id
name # name
} # }
} # }
} # }
} }
} }
} }
@ -34,42 +34,43 @@ def test_mutations():
}, },
'faction': { 'faction': {
'name': 'Alliance to Restore the Republic', 'name': 'Alliance to Restore the Republic',
'ships': { # 'ships': {
'edges': [{ # 'edges': [{
'node': { # 'node': {
'id': 'U2hpcDox', # 'id': 'U2hpcDox',
'name': 'X-Wing' # 'name': 'X-Wing'
} # }
}, { # }, {
'node': { # 'node': {
'id': 'U2hpcDoy', # 'id': 'U2hpcDoy',
'name': 'Y-Wing' # 'name': 'Y-Wing'
} # }
}, { # }, {
'node': { # 'node': {
'id': 'U2hpcDoz', # 'id': 'U2hpcDoz',
'name': 'A-Wing' # 'name': 'A-Wing'
} # }
}, { # }, {
'node': { # 'node': {
'id': 'U2hpcDo0', # 'id': 'U2hpcDo0',
'name': 'Millenium Falcon' # 'name': 'Millenium Falcon'
} # }
}, { # }, {
'node': { # 'node': {
'id': 'U2hpcDo1', # 'id': 'U2hpcDo1',
'name': 'Home One' # 'name': 'Home One'
} # }
}, { # }, {
'node': { # 'node': {
'id': 'U2hpcDo5', # 'id': 'U2hpcDo5',
'name': 'Peter' # 'name': 'Peter'
} # }
}] # }]
}, # },
} }
} }
} }
result = schema.execute(query) result = schema.execute(query)
# raise result.errors[0].original_error, None, result.errors[0].stack
assert not result.errors assert not result.errors
assert result.data == expected assert result.data == expected

View File

@ -1,2 +1,2 @@
from .node import Node from .node import Node
# from .mutation import ClientIDMutation from .mutation import ClientIDMutation

View File

@ -0,0 +1,62 @@
from functools import partial
import copy
import six
from graphql_relay import mutation_with_client_mutation_id
from ..types.mutation import Mutation, MutationMeta
from ..types.inputobjecttype import GrapheneInputObjectType, InputObjectType
from ..types.objecttype import GrapheneObjectType
from ..types.field import Field, InputField
from ..utils.props import props
class ClientIDMutationMeta(MutationMeta):
_construct_field = False
def get_options(cls, meta):
options = cls.options_class(
meta,
name=None,
abstract=False
)
options.graphql_type = None
options.interfaces = []
return options
def construct(cls, bases, attrs):
if not cls._meta.abstract:
Input = attrs.pop('Input', None)
field_attrs = props(Input) if Input else {}
cls.mutate_and_get_payload = attrs.pop('mutate_and_get_payload', None)
input_local_fields = {f.name: f for f in InputObjectType._extract_local_fields(field_attrs)}
local_fields = cls._extract_local_fields(attrs)
assert cls.mutate_and_get_payload, "{}.mutate_and_get_payload method is required in a ClientIDMutation ObjectType.".format(cls.__name__)
field = mutation_with_client_mutation_id(
name=cls._meta.name or cls.__name__,
input_fields=input_local_fields,
output_fields=cls._fields(bases, attrs, local_fields),
mutate_and_get_payload=cls.mutate_and_get_payload,
input_type_class=partial(GrapheneInputObjectType, graphene_type=cls),
input_field_class=InputField,
output_type_class=partial(GrapheneObjectType, graphene_type=cls),
field_class=Field,
)
cls._meta.graphql_type = field.type
cls._Field = field
constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs)
return constructed
@property
def Field(cls):
field = copy.copy(cls._Field)
field.reset_counter()
return field
class ClientIDMutation(six.with_metaclass(ClientIDMutationMeta, Mutation)):
class Meta:
abstract = True

View File

@ -1,7 +1,7 @@
import copy import copy
from functools import partial from functools import partial
import six import six
from graphql_relay import node_definitions, from_global_id from graphql_relay import node_definitions, from_global_id, to_global_id
from ..types.field import Field from ..types.field import Field
from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMeta from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMeta
@ -14,10 +14,12 @@ class NodeMeta(InterfaceTypeMeta):
def construct(cls, bases, attrs): def construct(cls, bases, attrs):
cls.get_node = attrs.pop('get_node') cls.get_node = attrs.pop('get_node')
cls.id_resolver = attrs.pop('id_resolver', None)
node_interface, node_field = node_definitions( node_interface, node_field = node_definitions(
cls.get_node, cls.get_node,
id_resolver=cls.id_resolver,
interface_class=partial(GrapheneInterfaceType, graphene_type=cls), interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
field_class=Field field_class=Field,
) )
cls._meta.graphql_type = node_interface cls._meta.graphql_type = node_interface
cls._Field = node_field cls._Field = node_field
@ -42,6 +44,14 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
def from_global_id(cls, global_id): def from_global_id(cls, global_id):
return from_global_id(global_id) return from_global_id(global_id)
@classmethod
def to_global_id(cls, type, id):
return to_global_id(type, id)
@classmethod
def id_resolver(cls, root, args, context, info):
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None))
@classmethod @classmethod
def get_node(cls, global_id, context, info): def get_node(cls, global_id, context, info):
try: try:
@ -61,4 +71,5 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
''' '''
if cls.require_get_node(): if cls.require_get_node():
assert hasattr(object_type, 'get_node'), '{}.get_node method is required by the Node interface.'.format(object_type.__name__) assert hasattr(object_type, 'get_node'), '{}.get_node method is required by the Node interface.'.format(object_type.__name__)
return super(Node, cls).implements(object_type) return super(Node, cls).implements(object_type)

View File

@ -0,0 +1,84 @@
import pytest
from graphql_relay import to_global_id
from ..mutation import ClientIDMutation
from ...types import ObjectType, Schema, implements
from ...types.scalars import String
class SaySomething(ClientIDMutation):
class Input:
what = String()
phrase = String()
@staticmethod
def mutate_and_get_payload(args, context, info):
what = args.get('what')
return SaySomething(phrase=str(what))
class RootQuery(ObjectType):
something = String()
class Mutation(ObjectType):
say = SaySomething.Field
schema = Schema(query=RootQuery, mutation=Mutation)
def test_no_mutate_and_get_payload():
with pytest.raises(AssertionError) as excinfo:
class MyMutation(ClientIDMutation):
pass
assert "MyMutation.mutate_and_get_payload method is required in a ClientIDMutation ObjectType." == str(excinfo.value)
def test_node_good():
graphql_type = SaySomething._meta.graphql_type
fields = graphql_type.get_fields()
assert 'phrase' in fields
assert SaySomething.Field.type == SaySomething._meta.graphql_type
graphql_field = SaySomething.Field
assert 'input' in graphql_field.args
input = graphql_field.args['input']
assert 'clientMutationId' in input.type.of_type.get_fields()
def test_node_query():
executed = schema.execute(
'mutation a { say(input: {what:"hello", clientMutationId:"1"}) { phrase } }'
)
assert not executed.errors
assert executed.data == {'say': {'phrase': 'hello'}}
# def test_node_query_incorrect_id():
# executed = schema.execute(
# '{ node(id:"%s") { ... on MyNode { name } } }' % "something:2"
# )
# assert not executed.errors
# assert executed.data == {'node': None}
# def test_str_schema():
# assert str(schema) == """
# schema {
# query: RootQuery
# }
# type MyNode implements Node {
# id: ID!
# name: String
# }
# interface Node {
# id: ID!
# }
# type RootQuery {
# first: String
# node(id: ID!): Node
# }
# """.lstrip()

View File

@ -11,7 +11,7 @@ from .argument import to_arguments
class AbstractField(object): class AbstractField(object):
@property @property
def name(self): def name(self):
return self._name or to_camel_case(self.attname) return self._name or self.attname and to_camel_case(self.attname)
@name.setter @name.setter
def name(self, name): def name(self, name):
@ -81,7 +81,8 @@ class Field(AbstractField, GraphQLField, OrderedType):
# We try to get the resolver from the interfaces # We try to get the resolver from the interfaces
if not resolver and issubclass(self.parent, ObjectType): if not resolver and issubclass(self.parent, ObjectType):
graphql_type = self.parent._meta.graphql_type graphql_type = self.parent._meta.graphql_type
for interface in graphql_type._provided_interfaces: interfaces = graphql_type._provided_interfaces or []
for interface in interfaces:
if not isinstance(interface, GrapheneInterfaceType): if not isinstance(interface, GrapheneInterfaceType):
continue continue
fields = interface.get_fields() fields = interface.get_fields()

View File

@ -9,17 +9,19 @@ from ..utils.props import props
class MutationMeta(ObjectTypeMeta): class MutationMeta(ObjectTypeMeta):
def construct_field(cls, field_attrs): _construct_field = True
def construct_field(cls, field_args):
resolver = getattr(cls, 'mutate', None) resolver = getattr(cls, 'mutate', None)
assert resolver, 'All mutations must define a mutate method in it' assert resolver, 'All mutations must define a mutate method in it'
return partial(Field, cls, args=field_attrs, resolver=resolver) return partial(Field, cls, args=field_args, resolver=resolver)
def construct(cls, bases, attrs): def construct(cls, bases, attrs):
super(MutationMeta, cls).construct(bases, attrs) super(MutationMeta, cls).construct(bases, attrs)
if not cls._meta.abstract: if not cls._meta.abstract and cls._construct_field:
Input = attrs.pop('Input', None) Input = attrs.pop('Input', None)
field_attrs = props(Input) if Input else {} field_args = props(Input) if Input else {}
cls.Field = cls.construct_field(field_attrs) cls.Field = cls.construct_field(field_args)
return cls return cls

View File

@ -14,6 +14,8 @@ class GrapheneObjectType(GrapheneGraphQLType, GraphQLObjectType):
self.check_interfaces() self.check_interfaces()
def check_interfaces(self): def check_interfaces(self):
if not self._provided_interfaces:
return
for interface in self._provided_interfaces: for interface in self._provided_interfaces:
if isinstance(interface, GrapheneInterfaceType): if isinstance(interface, GrapheneInterfaceType):
interface.graphene_type.implements(self.graphene_type) interface.graphene_type.implements(self.graphene_type)

View File

@ -97,108 +97,108 @@ def test_parent_container_get_fields():
assert fields.keys() == ['field1', 'field2'] assert fields.keys() == ['field1', 'field2']
# def test_objecttype_as_container_only_args(): def test_objecttype_as_container_only_args():
# container = Container("1", "2") container = Container("1", "2")
# assert container.field1 == "1" assert container.field1 == "1"
# assert container.field2 == "2" assert container.field2 == "2"
# def test_objecttype_as_container_args_kwargs(): def test_objecttype_as_container_args_kwargs():
# container = Container("1", field2="2") container = Container("1", field2="2")
# assert container.field1 == "1" assert container.field1 == "1"
# assert container.field2 == "2" assert container.field2 == "2"
# def test_objecttype_as_container_few_kwargs(): def test_objecttype_as_container_few_kwargs():
# container = Container(field2="2") container = Container(field2="2")
# assert container.field2 == "2" assert container.field2 == "2"
# def test_objecttype_as_container_all_kwargs(): def test_objecttype_as_container_all_kwargs():
# container = Container(field1="1", field2="2") container = Container(field1="1", field2="2")
# assert container.field1 == "1" assert container.field1 == "1"
# assert container.field2 == "2" assert container.field2 == "2"
# def test_objecttype_as_container_extra_args(): def test_objecttype_as_container_extra_args():
# with pytest.raises(IndexError) as excinfo: with pytest.raises(IndexError) as excinfo:
# Container("1", "2", "3") Container("1", "2", "3")
# assert "Number of args exceeds number of fields" == str(excinfo.value) assert "Number of args exceeds number of fields" == str(excinfo.value)
# def test_objecttype_as_container_invalid_kwargs(): def test_objecttype_as_container_invalid_kwargs():
# with pytest.raises(TypeError) as excinfo: with pytest.raises(TypeError) as excinfo:
# Container(unexisting_field="3") Container(unexisting_field="3")
# assert "'unexisting_field' is an invalid keyword argument for this function" == str(excinfo.value) assert "'unexisting_field' is an invalid keyword argument for this function" == str(excinfo.value)
# def test_objecttype_reuse_graphql_type(): def test_objecttype_reuse_graphql_type():
# MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={
# 'field': GraphQLField(GraphQLString) 'field': GraphQLField(GraphQLString)
# }) })
# class GrapheneObjectType(ObjectType): class GrapheneObjectType(ObjectType):
# class Meta: class Meta:
# graphql_type = MyGraphQLType graphql_type = MyGraphQLType
# graphql_type = GrapheneObjectType._meta.graphql_type graphql_type = GrapheneObjectType._meta.graphql_type
# assert graphql_type == MyGraphQLType assert graphql_type == MyGraphQLType
# instance = GrapheneObjectType(field="A") instance = GrapheneObjectType(field="A")
# assert instance.field == "A" assert instance.field == "A"
# def test_objecttype_add_fields_in_reused_graphql_type(): def test_objecttype_add_fields_in_reused_graphql_type():
# MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={ MyGraphQLType = GraphQLObjectType('MyGraphQLType', fields={
# 'field': GraphQLField(GraphQLString) 'field': GraphQLField(GraphQLString)
# }) })
# with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
# class GrapheneObjectType(ObjectType): class GrapheneObjectType(ObjectType):
# field = Field(GraphQLString) field = Field(GraphQLString)
# class Meta: class Meta:
# graphql_type = MyGraphQLType graphql_type = MyGraphQLType
# assert """Field "MyGraphQLType.field" can only be mounted in ObjectType or Interface, received GrapheneObjectType.""" == str(excinfo.value) assert """Can't mount Fields in an ObjectType with a defined graphql_type""" == str(excinfo.value)
# def test_objecttype_graphql_interface(): def test_objecttype_graphql_interface():
# MyInterface = GraphQLInterfaceType('MyInterface', fields={ MyInterface = GraphQLInterfaceType('MyInterface', fields={
# 'field': GraphQLField(GraphQLString) 'field': GraphQLField(GraphQLString)
# }) })
# class GrapheneObjectType(ObjectType): class GrapheneObjectType(ObjectType):
# class Meta: class Meta:
# interfaces = [MyInterface] interfaces = [MyInterface]
# graphql_type = GrapheneObjectType._meta.graphql_type graphql_type = GrapheneObjectType._meta.graphql_type
# assert graphql_type.get_interfaces() == (MyInterface, ) assert graphql_type.get_interfaces() == (MyInterface, )
# # assert graphql_type.is_type_of(MyInterface, None, None) # assert graphql_type.is_type_of(MyInterface, None, None)
# fields = graphql_type.get_fields() fields = graphql_type.get_fields()
# assert 'field' in fields assert 'field' in fields
# def test_objecttype_graphene_interface(): def test_objecttype_graphene_interface():
# class GrapheneInterface(Interface): class GrapheneInterface(Interface):
# name = Field(GraphQLString) name = Field(GraphQLString)
# extended = Field(GraphQLString) extended = Field(GraphQLString)
# class GrapheneObjectType(ObjectType): class GrapheneObjectType(ObjectType):
# class Meta: class Meta:
# interfaces = [GrapheneInterface] interfaces = [GrapheneInterface]
# field = Field(GraphQLString) field = Field(GraphQLString)
# graphql_type = GrapheneObjectType._meta.graphql_type graphql_type = GrapheneObjectType._meta.graphql_type
# assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, ) assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, )
# assert graphql_type.is_type_of(GrapheneObjectType(), None, None) assert graphql_type.is_type_of(GrapheneObjectType(), None, None)
# fields = graphql_type.get_fields() fields = graphql_type.get_fields()
# assert 'field' in fields assert 'field' in fields
# assert 'extended' in fields assert 'extended' in fields
# assert 'name' in fields assert 'name' in fields
# assert fields['field'] > fields['extended'] > fields['name'] assert fields['field'] > fields['extended'] > fields['name']
# def test_objecttype_graphene_interface_extended(): # def test_objecttype_graphene_interface_extended():

View File

@ -31,7 +31,7 @@ def get_base_fields(cls, bases):
if attname in fields: if attname in fields:
continue continue
field = copy.copy(field) field = copy.copy(field)
if isinstance(field, Field): if isinstance(field, (Field, InputField)):
field.parent = cls field.parent = cls
fields.add(attname) fields.add(attname)
_fields.append(field) _fields.append(field)