Improved relay nodes and field copies

This commit is contained in:
Syrus Akbary 2016-06-10 00:23:31 -07:00
parent 522f769cad
commit 51e97510c0
9 changed files with 80 additions and 44 deletions

View File

@ -4,7 +4,7 @@ from graphene import implements, relay, resolve_only_args
from .data import create_ship, get_empire, get_faction, get_rebels, get_ship from .data import create_ship, get_empire, get_faction, get_rebels, get_ship
class Ship(graphene.ObjectType, relay.Node): class Ship(relay.Node, graphene.ObjectType):
'''A ship in the Star Wars saga''' '''A ship in the Star Wars saga'''
name = graphene.String(description='The name of the ship.') name = graphene.String(description='The name of the ship.')
@ -13,7 +13,7 @@ class Ship(graphene.ObjectType, relay.Node):
return get_ship(id) return get_ship(id)
class Faction(graphene.ObjectType, relay.Node): class Faction(relay.Node, graphene.ObjectType):
'''A faction in the Star Wars saga''' '''A faction in the Star Wars saga'''
name = graphene.String(description='The name of the faction.') name = graphene.String(description='The name of the faction.')
# ships = relay.ConnectionField( # ships = relay.ConnectionField(

View File

@ -39,14 +39,9 @@ class ClientIDMutationMeta(MutationMeta):
input_fields=input_local_fields, input_fields=input_local_fields,
output_fields=cls._fields(bases, attrs, local_fields), output_fields=cls._fields(bases, attrs, local_fields),
mutate_and_get_payload=cls.mutate_and_get_payload, mutate_and_get_payload=cls.mutate_and_get_payload,
input_type_class=partial(GrapheneInputObjectType, graphene_type=cls),
input_field_class=InputField,
output_type_class=partial(GrapheneObjectType, graphene_type=cls),
field_class=Field,
) )
cls._meta.graphql_type = field.type cls._meta.graphql_type = field.type
cls.Field = partial(Field.copy_and_extend, field, type=None, _creation_counter=None) cls.Field = partial(Field.copy_and_extend, field, type=field.type, _creation_counter=None)
constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs) constructed = super(ClientIDMutationMeta, cls).construct(bases, attrs)
return constructed return constructed

View File

@ -11,25 +11,30 @@ from ..types.interface import GrapheneInterfaceType, Interface, InterfaceTypeMet
class NodeMeta(ObjectTypeMeta): class NodeMeta(ObjectTypeMeta):
def construct(cls, bases, attrs): def construct(cls, bases, attrs):
if not cls.is_object_type(): is_object_type = cls.is_object_type()
cls.get_node = attrs.pop('get_node') cls = super(NodeMeta, cls).construct(bases, attrs)
if not is_object_type:
get_node_from_global_id = getattr(cls, 'get_node_from_global_id', None)
assert get_node_from_global_id, '{}.get_node method is required by the Node interface.'.format(cls.__name__)
cls.id_resolver = attrs.pop('id_resolver', None) cls.id_resolver = attrs.pop('id_resolver', None)
node_interface, node_field = node_definitions( node_interface, node_field = node_definitions(
cls.get_node, get_node_from_global_id,
id_resolver=cls.id_resolver, id_resolver=cls.id_resolver,
interface_class=partial(GrapheneInterfaceType, graphene_type=cls),
field_class=Field,
) )
cls._meta.graphql_type = node_interface cls._meta.graphql_type = node_interface
cls.Field = partial(Field.copy_and_extend, node_field, type=None, _creation_counter=None) cls.Field = partial(Field.copy_and_extend, node_field, type=node_field.type, _creation_counter=None)
return super(NodeMeta, cls).construct(bases, attrs) else:
# The interface provided by node_definitions is not an instance
# of GrapheneInterfaceType, so it will have no graphql_type,
# so will not trigger Node.implements
cls.implements(cls)
return cls
class Node(six.with_metaclass(NodeMeta, Interface)): class Node(six.with_metaclass(NodeMeta, Interface)):
@classmethod @classmethod
def require_get_node(cls): def require_get_node(cls):
return cls == Node return Node._meta.graphql_type in cls._meta.graphql_type._provided_interfaces
@classmethod @classmethod
def from_global_id(cls, global_id): def from_global_id(cls, global_id):
@ -44,7 +49,7 @@ class Node(six.with_metaclass(NodeMeta, Interface)):
return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None)) return cls.to_global_id(info.parent_type.name, getattr(root, 'id', None))
@classmethod @classmethod
def get_node(cls, global_id, context, info): def get_node_from_global_id(cls, global_id, context, info):
try: try:
_type, _id = cls.from_global_id(global_id) _type, _id = cls.from_global_id(global_id)
except: except:

View File

@ -7,7 +7,7 @@ from ...types import ObjectType, Schema, implements
from ...types.scalars import String from ...types.scalars import String
class MyNode(ObjectType, Node): class MyNode(Node, ObjectType):
name = String() name = String()
@ -25,9 +25,7 @@ schema = Schema(query=RootQuery, types=[MyNode])
def test_node_no_get_node(): def test_node_no_get_node():
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
class MyNode(ObjectType): class MyNode(Node, ObjectType):
class Meta:
interfaces = [Node]
pass pass
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value) assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)
@ -35,9 +33,8 @@ def test_node_no_get_node():
def test_node_no_get_node_with_meta(): def test_node_no_get_node_with_meta():
with pytest.raises(AssertionError) as excinfo: with pytest.raises(AssertionError) as excinfo:
class MyNode(ObjectType): class MyNode(Node, ObjectType):
class Meta: pass
interfaces = [Node]
assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value) assert "MyNode.get_node method is required by the Node interface." == str(excinfo.value)

View File

@ -7,7 +7,7 @@ from ...types.scalars import String, Int
class CustomNode(Node): class CustomNode(Node):
@staticmethod @staticmethod
def get_node(id, context, info): def get_node_from_global_id(id, context, info):
assert info.schema == schema assert info.schema == schema
if id in user_data: if id in user_data:
return user_data.get(id) return user_data.get(id)
@ -15,11 +15,11 @@ class CustomNode(Node):
return photo_data.get(id) return photo_data.get(id)
class User(ObjectType, CustomNode): class User(CustomNode, ObjectType):
name = String() name = String()
class Photo(ObjectType, CustomNode): class Photo(CustomNode, ObjectType):
width = Int() width = Int()

View File

@ -116,20 +116,40 @@ class Field(AbstractField, GraphQLField, OrderedType):
@classmethod @classmethod
def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args): def copy_and_extend(cls, field, type=None, args=None, resolver=None, source=None, deprecation_reason=None, name=None, description=None, required=False, _creation_counter=False, **extra_args):
if isinstance(field, Field):
type = type or field._type
resolver = resolver or field._resolver
source = source or field.source
name = name or field._name
required = required or field.required
_creation_counter = field.creation_counter if _creation_counter is False else None
attname = field.attname
parent = field.parent
else:
# If is a GraphQLField
type = type or field.type
resolver = resolver or field.resolver
source = None
name = field.name
required = None
_creation_counter = None
attname = name
parent = None
new_field = cls( new_field = cls(
type=type or field._type, type=type,
args=to_arguments(args, field.args), args=to_arguments(args, field.args),
resolver=field._resolver, resolver=resolver,
source=source or getattr(field, 'source', None), source=source,
deprecation_reason=field.deprecation_reason, deprecation_reason=field.deprecation_reason,
name=field._name, name=name,
required=required or getattr(field, 'required', False), required=required,
description=field.description, description=field.description,
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None, _creation_counter=_creation_counter,
**extra_args **extra_args
) )
new_field.attname = field.attname new_field.attname = attname
new_field.parent = field.parent new_field.parent = parent
return new_field return new_field
def __str__(self): def __str__(self):
@ -168,14 +188,30 @@ class InputField(AbstractField, GraphQLInputObjectField, OrderedType):
@classmethod @classmethod
def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False): def copy_and_extend(cls, field, type=None, default_value=None, description=None, name=None, required=False, _creation_counter=False):
if isinstance(field, Field):
type = type or field._type
name = name or field._name
required = required or field.required
_creation_counter = field.creation_counter if _creation_counter is False else None
attname = field.attname
parent = field.parent
else:
# If is a GraphQLField
type = type or field.type
name = field.name
required = None
_creation_counter = None
attname = None
parent = None
new_field = cls( new_field = cls(
type=type or field._type, type=type,
name=name or field._name, name=name,
required=required or field.required, required=required,
default_value=default_value or field.default_value, default_value=default_value or field.default_value,
description=description or field.description, description=description or field.description,
_creation_counter=getattr(field, 'creation_counter', None) if _creation_counter is False else None, _creation_counter=_creation_counter,
) )
new_field.attname = field.attname new_field.attname = attname
new_field.parent = field.parent new_field.parent = parent
return new_field return new_field

View File

@ -40,7 +40,10 @@ class Interface(six.with_metaclass(InterfaceTypeMeta)):
abstract = True abstract = True
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
from .objecttype import ObjectType
if not isinstance(self, ObjectType):
raise Exception("An interface cannot be intitialized") raise Exception("An interface cannot be intitialized")
super(Interface, self).__init__(*args, **kwargs)
@classmethod @classmethod
def implements(cls, object_type): def implements(cls, object_type):

View File

@ -71,6 +71,7 @@ class ObjectTypeMeta(InterfaceTypeMeta):
cls._meta.interfaces, cls._meta.interfaces,
cls.get_interfaces(bases), cls.get_interfaces(bases),
))) )))
local_fields = cls._extract_local_fields(attrs) local_fields = cls._extract_local_fields(attrs)
if not cls._meta.graphql_type: if not cls._meta.graphql_type:
cls = super(ObjectTypeMeta, cls).construct(bases, attrs) cls = super(ObjectTypeMeta, cls).construct(bases, attrs)

View File

@ -93,7 +93,6 @@ def test_objecttype_as_container_get_fields():
def test_parent_container_get_fields(): def test_parent_container_get_fields():
fields = Container._meta.graphql_type.get_fields() fields = Container._meta.graphql_type.get_fields()
print [(f.creation_counter, f.name) for f in fields.values()]
assert fields.keys() == ['field1', 'field2'] assert fields.keys() == ['field1', 'field2']