Improved extend interfaces syntax

This commit is contained in:
Syrus Akbary 2016-06-09 21:47:06 -07:00
parent d67b7bc6a1
commit d8d884c9be
8 changed files with 72 additions and 61 deletions

View File

@ -21,17 +21,11 @@ class Character(graphene.Interface):
return [get_character(f) for f in self.friends]
# @graphene.implements(Character)
class Human(graphene.ObjectType):
class Meta:
interfaces = [Character]
class Human(graphene.ObjectType, Character):
home_planet = graphene.String()
# @graphene.implements(Character)
class Droid(graphene.ObjectType):
class Meta:
interfaces = [Character]
class Droid(graphene.ObjectType, Character):
primary_function = graphene.String()

View File

@ -4,10 +4,7 @@ from graphene import implements, relay, resolve_only_args
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
# @implements(relay.Node)
class Ship(graphene.ObjectType):
class Meta:
interfaces = [relay.Node]
class Ship(graphene.ObjectType, relay.Node):
'''A ship in the Star Wars saga'''
name = graphene.String(description='The name of the ship.')
@ -16,10 +13,7 @@ class Ship(graphene.ObjectType):
return get_ship(id)
# @implements(relay.Node)
class Faction(graphene.ObjectType):
class Meta:
interfaces = [relay.Node]
class Faction(graphene.ObjectType, relay.Node):
'''A faction in the Star Wars saga'''
name = graphene.String(description='The name of the faction.')
# ships = relay.ConnectionField(

View File

@ -4,25 +4,24 @@ import six
from graphql_relay import node_definitions, from_global_id, to_global_id
from ..types.field import Field
from ..types.objecttype import ObjectTypeMeta
from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMeta
class NodeMeta(InterfaceTypeMeta):
def construct_graphql_type(cls, bases):
pass
class NodeMeta(ObjectTypeMeta):
def construct(cls, bases, attrs):
cls.get_node = attrs.pop('get_node')
cls.id_resolver = attrs.pop('id_resolver', None)
node_interface, node_field = node_definitions(
cls.get_node,
id_resolver=cls.id_resolver,
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
field_class=Field,
)
cls._meta.graphql_type = node_interface
cls._Field = node_field
if not cls.is_object_type():
cls.get_node = attrs.pop('get_node')
cls.id_resolver = attrs.pop('id_resolver', None)
node_interface, node_field = node_definitions(
cls.get_node,
id_resolver=cls.id_resolver,
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
field_class=Field,
)
cls._meta.graphql_type = node_interface
cls._Field = node_field
return super(NodeMeta, cls).construct(bases, attrs)
@property

View File

@ -7,9 +7,7 @@ from ...types import ObjectType, Schema, implements
from ...types.scalars import String
class MyNode(ObjectType):
class Meta:
interfaces = [Node]
class MyNode(ObjectType, Node):
name = String()

View File

@ -15,17 +15,11 @@ class CustomNode(Node):
return photo_data.get(id)
# @implements(CustomNode)
class User(ObjectType):
class Meta:
interfaces = [CustomNode]
class User(ObjectType, CustomNode):
name = String()
# @implements(CustomNode)
class Photo(ObjectType):
class Meta:
interfaces = [CustomNode]
class Photo(ObjectType, CustomNode):
width = Int()

View File

@ -69,34 +69,36 @@ class Field(AbstractField, GraphQLField, OrderedType):
self.attname = attname
self.parent = parent
def default_resolver(self, root, args, context, info):
return getattr(root, self.source or self.attname, None)
@property
def resolver(self):
from .objecttype import ObjectType
from .interface import GrapheneInterfaceType
def default_resolver(root, args, context, info):
return getattr(root, self.source or self.attname, None)
resolver = getattr(self.parent, 'resolve_{}'.format(self.attname), None)
# We try to get the resolver from the interfaces
if not resolver and issubclass(self.parent, ObjectType):
graphql_type = self.parent._meta.graphql_type
interfaces = graphql_type._provided_interfaces or []
for interface in interfaces:
if not isinstance(interface, GrapheneInterfaceType):
continue
fields = interface.get_fields()
if self.attname in fields:
resolver = getattr(interface.graphene_type, 'resolve_{}'.format(self.attname), None)
if resolver:
# We remove the bounding to the method
resolver = resolver #.__func__
break
# This is not needed anymore as Interfaces could be extended now with Python syntax
# if not resolver and issubclass(self.parent, ObjectType):
# graphql_type = self.parent._meta.graphql_type
# interfaces = graphql_type._provided_interfaces or []
# for interface in interfaces:
# if not isinstance(interface, GrapheneInterfaceType):
# continue
# fields = interface.get_fields()
# if self.attname in fields:
# resolver = getattr(interface.graphene_type, 'resolve_{}'.format(self.attname), None)
# if resolver:
# # We remove the bounding to the method
# resolver = resolver #.__func__
# break
if resolver:
resolver = resolver.__func__
else:
resolver = default_resolver
resolver = self.default_resolver
# def resolver_wrapper(root, *args, **kwargs):
# if not isinstance(root, self.parent):

View File

@ -1,10 +1,11 @@
from itertools import chain
import copy
import six
from graphql import GraphQLObjectType
from .definitions import FieldsMeta, ClassTypeMeta, GrapheneGraphQLType
from .interface import GrapheneInterfaceType
from .interface import GrapheneInterfaceType, InterfaceTypeMeta, Interface
class GrapheneObjectType(GrapheneGraphQLType, GraphQLObjectType):
@ -45,7 +46,7 @@ def get_interfaces(cls, interfaces):
yield graphql_type
class ObjectTypeMeta(FieldsMeta, ClassTypeMeta):
class ObjectTypeMeta(InterfaceTypeMeta):
def get_options(cls, meta):
return cls.options_class(
@ -57,9 +58,19 @@ class ObjectTypeMeta(FieldsMeta, ClassTypeMeta):
abstract=False
)
def get_interfaces(cls, bases):
return (b for b in bases if issubclass(b, Interface))
def is_object_type(cls):
return issubclass(cls, ObjectType)
def construct(cls, bases, attrs):
if not cls._meta.abstract:
interfaces = tuple(get_interfaces(cls, cls._meta.interfaces))
if not cls._meta.abstract and cls.is_object_type():
cls.get_interfaces(bases)
interfaces = tuple(get_interfaces(cls, chain(
cls._meta.interfaces,
cls.get_interfaces(bases),
)))
local_fields = cls._extract_local_fields(attrs)
if not cls._meta.graphql_type:
cls = super(ObjectTypeMeta, cls).construct(bases, attrs)

View File

@ -201,6 +201,25 @@ def test_objecttype_graphene_interface():
assert fields['field'] > fields['extended'] > fields['name']
def test_objecttype_graphene_inherit_interface():
class GrapheneInterface(Interface):
name = Field(GraphQLString)
extended = Field(GraphQLString)
class GrapheneObjectType(ObjectType, GrapheneInterface):
field = Field(GraphQLString)
graphql_type = GrapheneObjectType._meta.graphql_type
assert graphql_type.get_interfaces() == (GrapheneInterface._meta.graphql_type, )
assert graphql_type.is_type_of(GrapheneObjectType(), None, None)
fields = graphql_type.get_fields()
assert 'field' in fields
assert 'extended' in fields
assert 'name' in fields
assert issubclass(GrapheneObjectType, GrapheneInterface)
assert fields['field'] > fields['extended'] > fields['name']
# def test_objecttype_graphene_interface_extended():
# class GrapheneInterface(Interface):
# field = Field(GraphQLString)